async def run(pc, url="ws://39.102.116.49:8080"):
    """Drive WebRTC-style signalling for *pc* over a websocket.

    Reads JSON signalling messages from the websocket at *url*:

    * ``{"type": "offerOrAnswer", "msg": ...}`` sets the remote description.
      When that description is an offer, a ``FlagVideoStreamTrack`` is added,
      an answer is created and set locally, and the local description is sent
      back in the same ``offerOrAnswer`` envelope.
    * ``{"type": "candidate", "msg": ...}`` adds a remote ICE candidate on a
      best-effort basis (failures are ignored).

    Args:
        pc: the peer connection being negotiated (presumably an aiortc
            ``RTCPeerConnection`` -- confirm against caller).
        url: signalling server URL; defaults to the previously hard-coded
            address, so existing callers are unaffected.
    """
    # Owning the session in ``async with`` guarantees it is closed when the
    # receive loop ends or raises, instead of leaking connector resources.
    async with ClientSession() as session:
        async with session.ws_connect(url) as ws:
            async for msg in ws:
                if msg.type != WSMsgType.TEXT:
                    continue
                data = json.loads(msg.data)
                if data["type"] == "offerOrAnswer":
                    await pc.setRemoteDescription(
                        object_from_string(json.dumps(data["msg"])))
                    if data["msg"]["type"] == "offer":
                        pc.addTrack(FlagVideoStreamTrack())
                        await pc.setLocalDescription(await pc.createAnswer())
                        await ws.send_str(json.dumps({
                            "type": "offerOrAnswer",
                            "msg": json.loads(
                                object_to_string(pc.localDescription)),
                        }))
                elif data["type"] == "candidate":
                    try:
                        await pc.addIceCandidate(
                            object_from_string(json.dumps(data["msg"])))
                    except Exception:
                        # Best effort: a bad or late candidate must not tear
                        # down signalling. Narrowed from a bare ``except`` so
                        # task cancellation still propagates.
                        pass
class WsTransport(BaseOutboundTransport):
    """Websockets outbound transport class."""

    # URI schemes this transport handles.
    schemes = ("ws", "wss")

    def __init__(self) -> None:
        """Initialize an `WsTransport` instance."""
        super().__init__()
        self.logger = logging.getLogger(__name__)

    async def start(self):
        """Start the outbound transport.

        Creates the client session (with a ``DummyCookieJar``, so cookies are
        not kept) and returns ``self``.
        """
        jar = DummyCookieJar()
        self.client_session = ClientSession(cookie_jar=jar)
        return self

    async def stop(self):
        """Stop the outbound transport: close and drop the client session."""
        session = self.client_session
        await session.close()
        self.client_session = None

    async def handle_message(self, message: OutboundMessage):
        """
        Handle message from queue.

        Opens a websocket to ``message.endpoint``, sends the payload as one
        frame (binary for ``bytes``, text otherwise) and closes the socket.

        Args:
            message: `OutboundMessage` to send over transport implementation
        """
        payload = message.payload
        # aiohttp should automatically handle websocket sessions
        async with self.client_session.ws_connect(message.endpoint) as ws:
            sender = ws.send_bytes if isinstance(payload, bytes) else ws.send_str
            await sender(payload)
async def ws_backend(cb, session: aiohttp.ClientSession = None):
    """Forward upstream websocket TEXT frames to *cb*, reconnecting forever.

    Connects to ``http://{UPSTREAM_IP}:{UPSTREAM_PORT_WS}/`` and awaits
    ``cb(data, type_="StreamingData")`` for every TEXT frame. Other frame
    types are logged and skipped. After the connection ends, or fails with an
    ``aiohttp.ClientError``, it waits 0.1 s and reconnects.

    Args:
        cb: coroutine function that receives each text payload.
        session: optional shared client session. If omitted, a session is
            created here and closed when this coroutine exits (for example on
            cancellation). A caller-supplied session is never closed here.
    """
    ws_logger = logging.getLogger("rcmproxy.run.ws_backend")
    owns_session = session is None
    if owns_session:
        session = aiohttp.ClientSession()
    try:
        while True:
            try:
                async with session.ws_connect(
                        f"http://{UPSTREAM_IP}:{UPSTREAM_PORT_WS}/") as ws:
                    async for msg in ws:
                        if msg.type == aiohttp.WSMsgType.TEXT:
                            ws_logger.debug(
                                "Received WSMessage of type WSMsgType.TEXT")
                            await cb(msg.data, type_="StreamingData")
                        else:
                            ws_logger.warning(
                                f"Unexpected WSMsgType: {msg.type}")
                            ws_logger.debug(msg)
            except aiohttp.ClientError as e:
                ws_logger.debug(e, exc_info=True)
                ws_logger.warning(e)
            # Pause before reconnecting, after a clean close as well, so a
            # server that accepts and immediately drops us cannot cause a
            # busy reconnect loop.
            await asyncio.sleep(0.1)
    finally:
        if owns_session:
            await session.close()
async def game_client(host, gameid, player):
    """Simulate one player: log into *gameid* and send moves periodically.

    Updates the module-level ``stats`` list:

    * ``stats[0]`` counts opened connections;
    * ``stats[1]`` counts moves sent;
    * ``stats[2]`` sums the length of received TEXT data.

    Args:
        host: server host or IP; IPv6 literals are wrapped in brackets.
        gameid: room identifier sent in the login message.
        player: player index, used as the player name and to stagger moves.
    """
    if ":" in host:
        host = "[" + host + "]"  # IPv6 literal
    async with ClientSession() as session:
        async with session.ws_connect("http://%s:8888/ws" % host) as ws:
            stats[0] += 1
            # ``send_*`` are coroutines in aiohttp 3; without ``await`` the
            # frame is never written.
            await ws.send_json({
                "type": "login",
                "data": {
                    "room": gameid,
                    "name": str(player)
                }
            })

            async def make_moves():
                # Stagger the requests a bit
                slot = SECONDS_BETWEEN_MOVES // PLAYERS_PER_GAME
                tm = (time.time() - SECONDS_BETWEEN_MOVES + slot * player +
                      random.randrange(slot))
                while not ws.closed:
                    tm += SECONDS_BETWEEN_MOVES
                    delay = tm - time.time()
                    if delay > 0:
                        await asyncio.sleep(delay)
                    if ws.closed:
                        break
                    stats[1] += 1
                    await ws.send_str(move_data)

            mover = asyncio.ensure_future(make_moves())
            try:
                async for msg in ws:
                    if msg.type == WSMsgType.TEXT:
                        stats[2] += len(msg.data)
            finally:
                # Stop the mover before the socket and session are torn down,
                # and collect its outcome so no exception goes unretrieved.
                mover.cancel()
                await asyncio.gather(mover, return_exceptions=True)
async def ws_client(url='http://0.0.0.0:8080/ws'):
    """Interactive websocket client loop.

    Awaits ``promt(ws)`` once after connecting, then prints every received
    message and awaits ``promt(ws)`` again. ``promt`` is defined elsewhere;
    presumably it reads user input and sends it (confirm against its
    definition).

    Args:
        url: server websocket URL; defaults to the previously hard-coded one.
    """
    # ``async with`` closes the session when the loop ends or raises; the
    # original session was never closed.
    async with ClientSession() as session:
        async with session.ws_connect(url) as ws:
            await promt(ws)
            async for msg in ws:
                print('Receive from server: ', msg.data)
                await promt(ws)
async def call_ws(loop, data, expected=None):
    """Send *data* to the bank websocket and check the JSON reply.

    Connects to ``http://localhost:5000/bank``, sends *data* as JSON, reads
    one JSON reply and prints it. If *expected* is given, asserts that the
    reply equals it.

    Args:
        loop: event loop passed to ``ClientSession`` (kept for callers that
            supply it).
        data: JSON-serialisable request.
        expected: expected reply; ``None`` skips the check. Falsy values such
            as ``{}`` or ``0`` are still compared.

    Raises:
        AssertionError: if the reply differs from *expected*.
    """
    # ``async with`` closes the session even when the assertion fails.
    async with ClientSession(loop=loop) as session:
        async with session.ws_connect('http://localhost:5000/bank') as websocket:
            await websocket.send_json(data)
            response = await websocket.receive_json()
    print('resp', response, type(response))
    if expected is not None:
        print('exp', expected, type(expected))
        assert response == expected
async def handle_worker():
    """Register with the worker manager and dispatch its messages.

    Connects to ``ws://{env['WORKER_MANAGER_HOST']}:6060/workers`` and binds
    the socket to the module-global ``ws``. It then sends
    ``{"t": "identify", "d": None}``. For each TEXT frame it decodes
    ``{"t": name, "d": payload}`` and schedules ``handlers[name](payload)``
    on ``client.loop``. Unknown names are ignored.
    """
    global ws
    # Strong references to scheduled handler tasks while this coroutine runs:
    # the event loop only keeps weak references, so unreferenced tasks can be
    # garbage-collected mid-execution.
    pending = set()
    async with ClientSession() as session:
        async with session.ws_connect(
                f"ws://{env['WORKER_MANAGER_HOST']}:6060/workers") as ws:
            await ws.send_json({"t": "identify", "d": None})
            async for message in ws:
                if message.type != WSMsgType.TEXT:
                    continue
                data = message.json(loads=loads)
                handler = handlers.get(data["t"], None)
                if handler is None:
                    continue
                task = client.loop.create_task(handler(data["d"]))
                pending.add(task)
                task.add_done_callback(pending.discard)
async def ws_run_forever(
    url: StrOrURL,
    session: aiohttp.ClientSession,
    event: asyncio.Event,
    *,
    send_str: Optional[str] = None,
    send_json: Optional[Any] = None,
    hdlr_str=None,
    hdlr_json=None,
    **kwargs: Any,
) -> None:
    """Keep a websocket to *url* open for as long as *session* is open.

    On every successful connect, *event* is set and the optional *send_str*
    and *send_json* greetings are sent. Each TEXT frame goes to
    ``hdlr_str(data, ws)``. If the frame also parses as JSON, the decoded
    object goes to ``hdlr_json(obj, ws)``. Handlers may be plain functions or
    coroutine functions; their exceptions are logged, not propagated. An
    ERROR frame ends the current connection. Consecutive connection attempts
    start at least 60 seconds apart.

    Args:
        url: websocket URL.
        session: client session; the loop stops once it is closed.
        event: set each time a connection is established.
        send_str: text sent right after connecting, if not ``None``.
        send_json: object sent as JSON right after connecting, if not ``None``.
        hdlr_str: handler for raw text frames.
        hdlr_json: handler for text frames that decode as JSON.
        **kwargs: forwarded to ``session.ws_connect``.

    Raises:
        Any exception from connecting or receiving other than
        ``aiohttp.WSServerHandshakeError``, which is logged and retried.
    """
    iscorofunc_str = asyncio.iscoroutinefunction(hdlr_str)
    iscorofunc_json = asyncio.iscoroutinefunction(hdlr_json)
    while not session.closed:
        # Started before connecting, so reconnects are paced from the start
        # of the previous attempt.
        separator = asyncio.create_task(asyncio.sleep(60.0))
        try:
            async with session.ws_connect(url, **kwargs) as ws:
                event.set()
                if send_str is not None:
                    await ws.send_str(send_str)
                if send_json is not None:
                    await ws.send_json(send_json)
                async for msg in ws:
                    if msg.type == aiohttp.WSMsgType.TEXT:
                        if hdlr_str is not None:
                            try:
                                if iscorofunc_str:
                                    await hdlr_str(msg.data, ws)
                                else:
                                    hdlr_str(msg.data, ws)
                            except Exception as e:
                                logger.error(repr(e))
                        if hdlr_json is not None:
                            try:
                                data = msg.json()
                            except json.decoder.JSONDecodeError:
                                pass
                            else:
                                try:
                                    if iscorofunc_json:
                                        await hdlr_json(data, ws)
                                    else:
                                        hdlr_json(data, ws)
                                except Exception as e:
                                    logger.error(repr(e))
                    elif msg.type == aiohttp.WSMsgType.ERROR:
                        break
        except aiohttp.WSServerHandshakeError as e:
            logger.warning(repr(e))
        except BaseException:
            # Any other error (or cancellation) propagates. Cancel the pacing
            # task first so it is not left pending and destroyed.
            separator.cancel()
            raise
        await separator
async def main():
    """Greet the local websocket server and print its first text reply.

    Stops after the first TEXT message or on an ERROR message.
    """
    # ``async with`` closes the session; the original never closed it.
    async with ClientSession() as session:
        async with session.ws_connect('http://0.0.0.0:8080/ws') as ws:
            await ws.send_str('Hello server! It is WS Client!')
            async for msg in ws:
                if msg.type == aiohttp.WSMsgType.TEXT:
                    print(msg.data)
                    break
                elif msg.type == aiohttp.WSMsgType.ERROR:
                    break
class ReceiverService:
    """Receive events from the egress websocket and pass them to a consumer.

    NOTE(review): the ``ClientSession`` created in ``__init__`` is never
    closed by this class; ``close()`` only stops the reconnect loop.
    """

    def __init__(self, billing_id: str, client_id: str, client_secret: str,
                 config: ClientConfig):
        self._logger = config.get_logger(__name__)
        self.auth_service = AuthService(
            purpose=type(self).__name__,
            billing_id=billing_id,
            client_id=client_id,
            client_secret=client_secret,
            config=config,
        )
        self._config = config
        self._client = None
        self._running = False
        self._session = ClientSession()

    async def start_timer(self):
        """Start the auth service without opening the websocket."""
        await self.auth_service.start()

    def _connect_headers(self):
        """Build the handshake headers; the access token is fetched per call."""
        return {
            'Authorization': f'Bearer {self.auth_service.get_access_token()}',
            'Strm-Driver-Version': self._config.version.brief_string(),
            'Strm-Driver-Build': self._config.version.release_string(),
        }

    async def start(self, as_json: bool, consumer: Callable[[Any], Any]):
        """Receive until ``close()`` is called, reconnecting after each disconnect.

        Args:
            as_json: when true, ``?asJson=true`` is appended to the egress URI.
            consumer: coroutine function awaited with each TEXT payload.
        """
        self._running = True
        await self.auth_service.start()
        query = "?asJson=true" if as_json else ""
        uri = self._config.egress_uri + query
        while self._running:
            async with self._session.ws_connect(
                    uri, headers=self._connect_headers()) as ws:
                async for msg in ws:
                    kind = msg.type
                    if kind == aiohttp.WSMsgType.TEXT:
                        await consumer(msg.data)
                        continue
                    if kind == aiohttp.WSMsgType.CLOSED:
                        self._logger.debug("Websocket connection closed")
                        break
                    if kind == aiohttp.WSMsgType.ERROR:
                        self._logger.debug(
                            "Error upon receiving data from websocket")
                        break

    def close(self):
        """Ask the receive loop in ``start`` to stop after the current connection."""
        self._running = False
async def connect_to_service_provider(task):
    """Maintain one websocket connection to a service provider.

    *task* is a mutable mapping holding at least ``'url'`` and ``'desc'``.
    Its ``'active'`` flag prevents two connections for the same task from
    running at once, and is cleared on exit. Each TEXT frame is forwarded to
    ``handle_service_provider_message``; the literal text ``'close'`` closes
    the socket instead. On exit, the provider registered under the socket's
    id is destroyed via ``destroy_service_provider``.

    Ordinary exceptions are logged and swallowed. Cancellation and
    ``KeyboardInterrupt`` are logged, cleaned up after, and re-raised.
    """
    if 'active' in task and task['active']:
        return
    task['active'] = True
    desc = task['desc']
    ws_id = ''
    session = None  # bound before ``try`` so ``finally`` can always test it
    try:
        session = ClientSession()
        async with session.ws_connect(task['url'], heartbeat=30) as ws:
            ws_id = ws.id = str(uuid.uuid4())
            closed_with_error = False
            async for msg in ws:
                if msg.type == WSMsgType.TEXT:
                    if msg.data == 'close':
                        await ws.close()
                    else:
                        await handle_service_provider_message(
                            APP, msg.data, ws)
                elif msg.type == WSMsgType.CLOSE:
                    break
                elif msg.type == WSMsgType.ERROR:
                    closed_with_error = True
                    log.server_logger.info(
                        '%s connection closed with error %s', desc,
                        ws.exception())
                    break
            # Report a normal close exactly once, and never after an error.
            if not closed_with_error:
                log.server_logger.info('%s connection closed normally', desc)
    except Exception as e:
        log.server_logger.exception('%s connection closed with exception: %s',
                                    desc, e)
    except BaseException:
        # CancelledError / KeyboardInterrupt: log, let ``finally`` clean up,
        # then re-raise so cancelling this task actually cancels it.
        log.server_logger.exception('%s uncaught exception!', desc)
        raise
    finally:
        if session is not None and not session.closed:
            try:
                await session.close()
            except Exception:
                log.server_logger.exception(
                    '%s failed to close client session', desc)
        if ws_id:
            await destroy_service_provider(APP, ws_id)
        task['active'] = False
class WsTransport(BaseOutboundTransport):
    """Websockets outbound transport class."""

    # URI schemes this transport handles.
    schemes = ("ws", "wss")

    def __init__(self, **kwargs) -> None:
        """Initialize an `WsTransport` instance.

        All keyword arguments are forwarded to the base transport.
        """
        super().__init__(**kwargs)
        self.logger = logging.getLogger(__name__)

    async def start(self):
        """Start the outbound transport.

        Creates a client session with a ``DummyCookieJar`` and
        ``trust_env=True``, then returns ``self``.
        """
        self.client_session = ClientSession(
            cookie_jar=DummyCookieJar(),
            trust_env=True,
        )
        return self

    async def stop(self):
        """Stop the outbound transport: close and drop the client session."""
        session = self.client_session
        await session.close()
        self.client_session = None

    async def handle_message(
        self,
        profile: Profile,
        payload: Union[str, bytes],
        endpoint: str,
        metadata: dict = None,
        api_key: str = None,
    ):
        """
        Handle message from queue.

        Opens a websocket to *endpoint* (sending *metadata* as handshake
        headers), writes *payload* as one frame and closes the socket.

        Args:
            profile: the profile that produced the message (unused here)
            payload: message payload in string or byte format
            endpoint: URI endpoint for delivery
            metadata: Additional metadata associated with the payload
            api_key: accepted for interface compatibility; unused here
        """
        # aiohttp should automatically handle websocket sessions
        connection = self.client_session.ws_connect(endpoint, headers=metadata)
        async with connection as ws:
            if not isinstance(payload, bytes):
                await ws.send_str(payload)
                return
            await ws.send_bytes(payload)
class WebsocketDanmuService:
    """Websocket client that sends start-up payloads and heartbeats and
    forwards received messages.

    After connecting, every item in *payloads* is sent as a binary frame.
    Then *hb* is sent every *interval* seconds, while each received message
    is passed to *callback*. A coroutine-function callback is awaited; any
    other callback runs in the default executor.
    """

    def __init__(self, ws_address: str, payloads: list, hb: bytes,
                 interval: int,
                 callback: Callable[[WSMessage], Optional[Awaitable]]):
        self.ws: Optional[ClientWebSocketResponse] = None
        self.session = ClientSession()
        self.ws_address = ws_address
        self.payloads = payloads
        self.hb = hb
        self.interval = interval
        self.heartbeat_task: Optional['Future'] = None
        self.cb = callback

    async def connect(self):
        """Open the websocket and run until the connection ends."""
        async with self.session.ws_connect(self.ws_address) as ws:
            self.ws = ws
            await self.running()

    async def send_heartbeat_once(self):
        """Send one heartbeat frame (``self.hb``)."""
        await self.ws.send_bytes(self.hb)

    async def send_heartbeat(self):
        """Send a heartbeat every ``self.interval`` seconds, forever."""
        while True:
            await asyncio.sleep(self.interval)
            await self.send_heartbeat_once()

    async def _cancel_heartbeat(self):
        """Cancel the heartbeat task, if any, and wait for it to finish.

        A heartbeat that already failed (for example, a send on a dropped
        socket) is best effort, so its error is not re-raised here.
        """
        if self.heartbeat_task is None:
            return
        self.heartbeat_task.cancel()
        with suppress(asyncio.CancelledError, Exception):
            await self.heartbeat_task

    async def running(self):
        """Send start-up payloads, start heartbeats, and dispatch messages."""
        for payload in self.payloads:
            await self.ws.send_bytes(payload)
        self.heartbeat_task = asyncio.ensure_future(self.send_heartbeat())
        try:
            async for msg in self.ws:
                if asyncio.iscoroutinefunction(self.cb):
                    await self.cb(msg)
                else:
                    loop = asyncio.get_event_loop()
                    await loop.run_in_executor(None, self.cb, msg)
        finally:
            # Once the server hangs up, stop heartbeating into a closed socket.
            await self._cancel_heartbeat()

    async def stop(self):
        """Stop heartbeats and close the websocket and the session.

        Safe to call before ``connect()`` or after the connection has ended.
        """
        await self._cancel_heartbeat()
        if self.ws is not None:
            await self.ws.close()
        await self.session.close()
class WsTransport(BaseOutboundTransport):
    """Websockets outbound transport class."""

    # URI schemes this transport handles.
    schemes = ("ws", "wss")

    def __init__(self, queue: BaseOutboundMessageQueue) -> None:
        """Initialize a `WsTransport` instance.

        Args:
            queue: outbound message queue, exposed via the ``queue`` property.
        """
        self.logger = logging.getLogger(__name__)
        self._queue = queue

    async def __aenter__(self):
        """Async context manager enter: open the client session."""
        self.client_session = ClientSession()
        return self

    async def __aexit__(self, *err):
        """Async context manager exit: close the client session.

        The exception info is logged at ERROR level only when the ``async
        with`` body actually raised; a clean exit logs nothing.
        """
        await self.client_session.close()
        self.client_session = None
        if err and err[0] is not None:
            self.logger.error(err)

    @property
    def queue(self):
        """Accessor for queue."""
        return self._queue

    async def handle_message(self, message: OutboundMessage):
        """
        Handle message from queue.

        Sends the payload as one frame (binary for ``bytes``, text otherwise).
        Any failure is logged, not raised.

        Args:
            message: `OutboundMessage` to send over transport implementation
        """
        try:
            # As an example, we can open a websocket channel, send a message, then
            # close the channel immediately. This is not optimal but it works.
            async with self.client_session.ws_connect(message.endpoint) as ws:
                if isinstance(message.payload, bytes):
                    await ws.send_bytes(message.payload)
                else:
                    await ws.send_str(message.payload)
        except Exception:
            # TODO: add retry logic
            self.logger.exception("Error handling outbound message")
async def _websocket_connect(self, endpoint: str, session: aiohttp.ClientSession) -> None:
    """
    Helper method to create websocket connection with specified *endpoint*
    using the specified :class:`aiohttp.ClientSession`. Once connected, we
    initialise and start the GraphQL subscription; then wait for any incoming
    messages. Any message received via the websocket connection is cast into
    a :class:`GraphQLSubscriptionEvent` instance and dispatched for handling
    via :method:`handle`.

    If the coroutine is cancelled (or interrupted) while receiving, a
    connection-stop request is sent over the still-open socket before the
    cancellation is re-raised, so the awaiting task really ends as cancelled.

    :param endpoint: Endpoint to use when creating the websocket connection.
    :param session: Session to use when creating the websocket connection.
    :raises asyncio.CancelledError: re-raised after the stop request is sent.
    """
    async with session.ws_connect(endpoint) as ws:
        await ws.send_json(data=self.connection_init_request())

        # The start request is only sent once the server acknowledges init.
        self.callbacks.register(
            GraphQLSubscriptionEventType.CONNECTION_ACK,
            SimpleTriggerCallback(function=ws.send_json,
                                  data=self.connection_start_request()),
        )

        try:
            async for msg in ws:  # type: aiohttp.WSMessage
                if msg.type != aiohttp.WSMsgType.TEXT:
                    if msg.type == aiohttp.WSMsgType.ERROR:
                        break
                    continue

                event = GraphQLSubscriptionEvent(
                    subscription_id=self.id,
                    request=self.request,
                    json=msg.json(),
                )
                await self.handle(event=event)

                if self.is_stop_event(event):
                    break
        except (asyncio.CancelledError, KeyboardInterrupt):
            if not ws.closed:
                await ws.send_json(data=self.connection_stop_request())
            # Propagate instead of swallowing, so task cancellation works.
            raise
class WsTransport(BaseOutboundTransport):
    """Websockets outbound transport class."""

    # URI schemes this transport handles.
    schemes = ("ws", "wss")

    def __init__(self) -> None:
        """Initialize an `WsTransport` instance."""
        super().__init__()
        self.logger = logging.getLogger(__name__)

    async def start(self):
        """Start the outbound transport with a cookie-less client session.

        Returns ``self``.
        """
        self.client_session = ClientSession(cookie_jar=DummyCookieJar())
        return self

    async def stop(self):
        """Stop the outbound transport: close and drop the client session."""
        await self.client_session.close()
        self.client_session = None

    async def handle_message(self, context: InjectionContext,
                             payload: Union[str, bytes], endpoint: str):
        """
        Handle message from queue.

        Opens a websocket to *endpoint*, writes *payload* as one frame and
        closes the socket again.

        Args:
            context: the context that produced the message (unused here)
            payload: message payload in string or byte format
            endpoint: URI endpoint for delivery
        """
        is_binary = isinstance(payload, bytes)
        # aiohttp should automatically handle websocket sessions
        async with self.client_session.ws_connect(endpoint) as ws:
            if is_binary:
                await ws.send_bytes(payload)
            else:
                await ws.send_str(payload)
class WsConn(Conn):
    """Websocket-backed ``Conn``.

    The url format is ``ws://hostname:port/…`` or ``wss://hostname:port/…``.
    ``open``, ``send_bytes``, ``read_bytes`` and ``read_json`` never raise;
    they report failure through their return value instead.
    """

    def __init__(
            self,
            url: str,
            receive_timeout: Optional[float] = None,
            session: Optional[ClientSession] = None,
            ws_receive_timeout: Optional[float] = None,  # used for automatic ping/pong
            ws_heartbeat: Optional[float] = None):  # used for automatic ping/pong
        super().__init__(receive_timeout)
        scheme = urlparse(url).scheme
        assert scheme in ('ws', 'wss')
        self._url = url

        # A caller-supplied session is shared and is not closed by ``clean``.
        self._is_sharing_session = session is not None
        self._session = session if self._is_sharing_session else ClientSession()

        self._ws_receive_timeout = ws_receive_timeout
        self._ws_heartbeat = ws_heartbeat
        self._ws = None

    async def open(self) -> bool:
        """Connect, giving up after 3 seconds; return whether it succeeded."""
        try:
            connecting = self._session.ws_connect(
                self._url,
                receive_timeout=self._ws_receive_timeout,
                heartbeat=self._ws_heartbeat,
            )
            self._ws = await asyncio.wait_for(connecting, timeout=3)
        except Exception:  # asyncio.TimeoutError included
            return False
        return True

    async def close(self) -> bool:
        """Close the websocket if one was opened; always returns True."""
        ws = self._ws
        if ws is not None:
            await ws.close()
        return True

    async def clean(self):
        """Close the session, but only if this connection created it."""
        if self._is_sharing_session:
            return
        await self._session.close()

    async def send_bytes(self, bytes_data) -> bool:
        """Send a binary frame; return False on any error or cancellation."""
        try:
            await self._ws.send_bytes(bytes_data)
        except (asyncio.CancelledError, Exception):
            return False
        return True

    async def read_bytes(self) -> Optional[bytes]:
        """Receive one binary frame, or None on timeout or any error."""
        try:
            return await asyncio.wait_for(self._ws.receive_bytes(),
                                          timeout=self._receive_timeout)
        except Exception:  # asyncio.TimeoutError included
            return None

    async def read_json(self) -> Any:
        """Receive one TEXT or BINARY (utf8) frame and decode it as JSON.

        Returns None on timeout, on any error, or for other frame types.
        """
        try:
            msg = await asyncio.wait_for(self._ws.receive(),
                                         timeout=self._receive_timeout)
            if msg.type == WSMsgType.BINARY:
                text = msg.data.decode('utf8')
            elif msg.type == WSMsgType.TEXT:
                text = msg.data
            else:
                return None
            return json.loads(text)
        except Exception:  # asyncio.TimeoutError included
            return None
class TestAdminServer(AsyncTestCase):
    """Tests for ``AdminServer`` middlewares, lifecycle and HTTP endpoints."""

    async def setUp(self):
        self.message_results = []
        self.webhook_results = []
        self.port = 0

        self.connector = TCPConnector(limit=16, limit_per_host=4)
        self.client_session = ClientSession(
            cookie_jar=DummyCookieJar(), connector=self.connector
        )

    async def tearDown(self):
        if self.client_session:
            await self.client_session.close()
            self.client_session = None

    async def test_debug_middleware(self):
        with async_mock.patch.object(
            test_module, "LOGGER", async_mock.MagicMock()
        ) as mock_logger:
            mock_logger.isEnabledFor = async_mock.MagicMock(return_value=True)
            mock_logger.debug = async_mock.MagicMock()

            request = async_mock.MagicMock(
                method="GET",
                path_qs="/hello/world?a=1&b=2",
                match_info={"match": "info"},
                text=async_mock.CoroutineMock(return_value="abc123"),
            )
            handler = async_mock.CoroutineMock()

            await test_module.debug_middleware(request, handler)
            mock_logger.isEnabledFor.assert_called_once()
            assert mock_logger.debug.call_count == 3

    async def test_ready_middleware(self):
        with async_mock.patch.object(
            test_module, "LOGGER", async_mock.MagicMock()
        ) as mock_logger:
            mock_logger.isEnabledFor = async_mock.MagicMock(return_value=True)
            mock_logger.debug = async_mock.MagicMock()
            mock_logger.info = async_mock.MagicMock()
            mock_logger.error = async_mock.MagicMock()

            request = async_mock.MagicMock(
                rel_url="/", app=async_mock.MagicMock(_state={"ready": False})
            )
            handler = async_mock.CoroutineMock(return_value="OK")
            with self.assertRaises(test_module.web.HTTPServiceUnavailable):
                await test_module.ready_middleware(request, handler)

            request.app._state["ready"] = True
            assert await test_module.ready_middleware(request, handler) == "OK"

            request.app._state["ready"] = True
            handler = async_mock.CoroutineMock(
                side_effect=test_module.LedgerConfigError("Bad config")
            )
            with self.assertRaises(test_module.LedgerConfigError):
                await test_module.ready_middleware(request, handler)

            request.app._state["ready"] = True
            handler = async_mock.CoroutineMock(
                side_effect=test_module.web.HTTPFound(location="/api/doc")
            )
            with self.assertRaises(test_module.web.HTTPFound):
                await test_module.ready_middleware(request, handler)

            request.app._state["ready"] = True
            handler = async_mock.CoroutineMock(
                side_effect=test_module.asyncio.CancelledError("Cancelled")
            )
            with self.assertRaises(test_module.asyncio.CancelledError):
                await test_module.ready_middleware(request, handler)

            request.app._state["ready"] = True
            handler = async_mock.CoroutineMock(
                side_effect=KeyError("No such thing")
            )
            with self.assertRaises(KeyError):
                await test_module.ready_middleware(request, handler)

    def get_admin_server(
        self, settings: dict = None, context: InjectionContext = None
    ) -> AdminServer:
        """Build an ``AdminServer`` on a fresh unused port.

        A ``"task_queue"`` key in *settings* is popped. When it is truthy, the
        server gets a ``TaskQueue``; otherwise it gets a mocked
        ``conductor_stats``.
        """
        if not context:
            context = InjectionContext()
        if settings:
            context.update_settings(settings)

        # middleware is task queue xor collector: cover both over test suite
        task_queue = (settings or {}).pop("task_queue", None)

        plugin_registry = async_mock.MagicMock(
            test_module.PluginRegistry, autospec=True
        )
        plugin_registry.post_process_routes = async_mock.MagicMock()
        context.injector.bind_instance(test_module.PluginRegistry, plugin_registry)

        collector = Collector()
        context.injector.bind_instance(test_module.Collector, collector)

        profile = InMemoryProfile.test_profile()

        self.port = unused_port()
        return AdminServer(
            "0.0.0.0",
            self.port,
            context,
            profile,
            self.outbound_message_router,
            self.webhook_router,
            conductor_stop=async_mock.CoroutineMock(),
            task_queue=TaskQueue(max_active=4) if task_queue else None,
            conductor_stats=(
                None
                if task_queue
                else async_mock.CoroutineMock(return_value={"a": 1})
            ),
        )

    async def outbound_message_router(self, *args):
        self.message_results.append(args)

    def webhook_router(self, *args):
        self.webhook_results.append(args)

    async def test_start_stop(self):
        with self.assertRaises(AssertionError):
            await self.get_admin_server().start()

        settings = {"admin.admin_insecure_mode": False}
        with self.assertRaises(AssertionError):
            await self.get_admin_server(settings).start()

        settings = {
            "admin.admin_insecure_mode": True,
            "admin.admin_api_key": "test-api-key",
        }
        with self.assertRaises(AssertionError):
            await self.get_admin_server(settings).start()

        settings = {
            "admin.admin_insecure_mode": False,
            "admin.admin_client_max_request_size": 4,
            "admin.admin_api_key": "test-api-key",
        }
        server = self.get_admin_server(settings)
        await server.start()
        assert server.app._client_max_size == 4 * 1024 * 1024
        with async_mock.patch.object(
            server, "websocket_queues", async_mock.MagicMock()
        ) as mock_wsq:
            mock_wsq.values = async_mock.MagicMock(
                return_value=[async_mock.MagicMock(stop=async_mock.MagicMock())]
            )
            await server.stop()

        with async_mock.patch.object(
            web.TCPSite, "start", async_mock.CoroutineMock()
        ) as mock_start:
            mock_start.side_effect = OSError("Failure to launch")
            with self.assertRaises(AdminSetupError):
                await self.get_admin_server(settings).start()

    async def test_import_routes(self):
        # this test just imports all default admin routes
        # for routes with associated tests, this shouldn't make a difference in coverage
        context = InjectionContext()
        context.injector.bind_instance(ProtocolRegistry, ProtocolRegistry())
        await DefaultContextBuilder().load_plugins(context)
        server = self.get_admin_server({"admin.admin_insecure_mode": True}, context)
        await server.make_application()

    async def test_import_routes_multitenant_middleware(self):
        # imports all default admin routes
        context = InjectionContext()
        context.injector.bind_instance(ProtocolRegistry, ProtocolRegistry())
        profile = InMemoryProfile.test_profile()
        context.injector.bind_instance(
            test_module.MultitenantManager,
            test_module.MultitenantManager(profile),
        )
        await DefaultContextBuilder().load_plugins(context)
        server = self.get_admin_server(
            {
                "admin.admin_insecure_mode": False,
                "admin.admin_api_key": "test-api-key",
            },
            context,
        )

        # cover multitenancy start code
        app = await server.make_application()
        app["swagger_dict"] = {}
        await server.on_startup(app)

        # multitenant authz
        [mt_authz_middle] = [
            m for m in app.middlewares
            if ".check_multitenant_authorization" in str(m)
        ]

        mock_request = async_mock.MagicMock(
            method="GET",
            headers={"Authorization": "Bearer ..."},
            path="/multitenancy/etc",
            text=async_mock.CoroutineMock(return_value="abc123"),
        )
        with self.assertRaises(test_module.web.HTTPUnauthorized):
            await mt_authz_middle(mock_request, None)

        mock_request = async_mock.MagicMock(
            method="GET",
            headers={},
            path="/protected/non-multitenancy/non-server",
            text=async_mock.CoroutineMock(return_value="abc123"),
        )
        with self.assertRaises(test_module.web.HTTPUnauthorized):
            await mt_authz_middle(mock_request, None)

        mock_request = async_mock.MagicMock(
            method="GET",
            headers={"Authorization": "Bearer ..."},
            path="/protected/non-multitenancy/non-server",
            text=async_mock.CoroutineMock(return_value="abc123"),
        )
        mock_handler = async_mock.CoroutineMock()
        await mt_authz_middle(mock_request, mock_handler)
        # ``called_once_with`` is an auto-created mock attribute (always
        # truthy); the real assertion method is ``assert_called_once_with``.
        mock_handler.assert_called_once_with(mock_request)

        # multitenant setup context exception paths
        [setup_ctx_middle] = [
            m for m in app.middlewares if ".setup_context" in str(m)
        ]

        mock_request = async_mock.MagicMock(
            method="GET",
            headers={"Authorization": "Non-bearer ..."},
            path="/protected/non-multitenancy/non-server",
            text=async_mock.CoroutineMock(return_value="abc123"),
        )
        with self.assertRaises(test_module.web.HTTPUnauthorized):
            await setup_ctx_middle(mock_request, None)

        mock_request = async_mock.MagicMock(
            method="GET",
            headers={"Authorization": "Bearer ..."},
            path="/protected/non-multitenancy/non-server",
            text=async_mock.CoroutineMock(return_value="abc123"),
        )
        with async_mock.patch.object(
            server.multitenant_manager,
            "get_profile_for_token",
            async_mock.CoroutineMock(),
        ) as mock_get_profile:
            mock_get_profile.side_effect = [
                test_module.MultitenantManagerError("corrupt token"),
                test_module.StorageNotFoundError("out of memory"),
            ]
            for _ in range(2):
                with self.assertRaises(test_module.web.HTTPUnauthorized):
                    await setup_ctx_middle(mock_request, None)

    async def test_register_external_plugin_x(self):
        context = InjectionContext()
        context.injector.bind_instance(ProtocolRegistry, ProtocolRegistry())
        with self.assertRaises(ValueError):
            builder = DefaultContextBuilder(
                settings={"external_plugins": "aries_cloudagent.nosuchmodule"}
            )
            await builder.load_plugins(context)

    async def test_visit_insecure_mode(self):
        settings = {"admin.admin_insecure_mode": True, "task_queue": True}
        server = self.get_admin_server(settings)
        await server.start()

        async with self.client_session.post(
            f"http://127.0.0.1:{self.port}/status/reset", headers={}
        ) as response:
            assert response.status == 200

        async with self.client_session.ws_connect(
            f"http://127.0.0.1:{self.port}/ws"
        ) as ws:
            result = await ws.receive_json()
            assert result["topic"] == "settings"

        for path in (
            "",
            "plugins",
            "status",
            "status/live",
            "status/ready",
            "shutdown",  # mock conductor has magic-mock stop()
        ):
            async with self.client_session.get(
                f"http://127.0.0.1:{self.port}/{path}", headers={}
            ) as response:
                assert response.status == 200

        await server.stop()

    async def test_visit_secure_mode(self):
        settings = {
            "admin.admin_insecure_mode": False,
            "admin.admin_api_key": "test-api-key",
        }
        server = self.get_admin_server(settings)
        await server.start()

        async with self.client_session.get(
            f"http://127.0.0.1:{self.port}/status",
            headers={"x-api-key": "wrong-key"},
        ) as response:
            assert response.status == 401

        async with self.client_session.get(
            f"http://127.0.0.1:{self.port}/status",
            headers={"x-api-key": "test-api-key"},
        ) as response:
            assert response.status == 200

        async with self.client_session.ws_connect(
            f"http://127.0.0.1:{self.port}/ws",
            headers={"x-api-key": "test-api-key"},
        ) as ws:
            result = await ws.receive_json()
            assert result["topic"] == "settings"

        await server.stop()

    async def test_query_config(self):
        settings = {
            "admin.admin_insecure_mode": False,
            "admin.admin_api_key": "test-api-key",
            "admin.webhook_urls": ["localhost:8123/abc#secret", "localhost:8123/def"],
            "multitenant.jwt_secret": "abc123",
            "wallet.key": "abc123",
            "wallet.rekey": "def456",
            "wallet.seed": "00000000000000000000000000000000",
            "wallet.storage.creds": "secret",
        }
        server = self.get_admin_server(settings)
        await server.start()

        async with self.client_session.get(
            f"http://127.0.0.1:{self.port}/status/config",
            headers={"x-api-key": "test-api-key"},
        ) as response:
            config = json.loads(await response.text())["config"]
            assert "admin.admin_insecure_mode" in config
            assert all(
                k not in config
                for k in [
                    "admin.admin_api_key",
                    "multitenant.jwt_secret",
                    "wallet.key",
                    "wallet.rekey",
                    "wallet.seed",
                    "wallet.storage.creds",
                ]
            )
            assert config["admin.webhook_urls"] == [
                "localhost:8123/abc",
                "localhost:8123/def",
            ]

        # Release the listening site like the other tests do.
        await server.stop()

    async def test_visit_shutting_down(self):
        settings = {
            "admin.admin_insecure_mode": True,
        }
        server = self.get_admin_server(settings)
        await server.start()

        async with self.client_session.get(
            f"http://127.0.0.1:{self.port}/shutdown", headers={}
        ) as response:
            assert response.status == 200

        async with self.client_session.get(
            f"http://127.0.0.1:{self.port}/status", headers={}
        ) as response:
            assert response.status == 503

        async with self.client_session.get(
            f"http://127.0.0.1:{self.port}/status/live", headers={}
        ) as response:
            assert response.status == 200
        await server.stop()

    async def test_server_health_state(self):
        settings = {
            "admin.admin_insecure_mode": True,
        }
        server = self.get_admin_server(settings)
        await server.start()

        async with self.client_session.get(
            f"http://127.0.0.1:{self.port}/status/live", headers={}
        ) as response:
            assert response.status == 200
            response_json = await response.json()
            assert response_json["alive"]

        async with self.client_session.get(
            f"http://127.0.0.1:{self.port}/status/ready", headers={}
        ) as response:
            assert response.status == 200
            response_json = await response.json()
            assert response_json["ready"]

        server.notify_fatal_error()
        async with self.client_session.get(
            f"http://127.0.0.1:{self.port}/status/live", headers={}
        ) as response:
            assert response.status == 503

        async with self.client_session.get(
            f"http://127.0.0.1:{self.port}/status/ready", headers={}
        ) as response:
            assert response.status == 503
        await server.stop()
class JupyterClient:
    """HTTP/websocket client for one user's JupyterHub lab.

    All requests share one ``ClientSession``. It sends a bearer token and a
    random XSRF token, the latter both as the ``x-xsrftoken`` header and as
    the ``_xsrf`` cookie.
    """

    log: BoundLoggerLazyProxy
    user: User
    session: ClientSession
    headers: Dict[str, str]
    xsrftoken: str
    jupyter_url: str

    def __init__(self, user: User, log: BoundLoggerLazyProxy,
                 options: Dict[str, Any]):
        self.user = user
        self.log = log
        self.jupyter_base = options.get("nb_url", "/nb/")
        self.jupyter_url = Configuration.environment_url + self.jupyter_base

        alphabet = string.ascii_uppercase + string.digits
        self.xsrftoken = "".join(random.choices(alphabet, k=16))
        self.jupyter_options_form = options.get("jupyter_options_form", {})

        self.headers = {
            "Authorization": "Bearer " + user.token,
            "x-xsrftoken": self.xsrftoken,
        }

        self.session = ClientSession(headers=self.headers)
        xsrf_cookie = BaseCookie({"_xsrf": self.xsrftoken})
        self.session.cookie_jar.update_cookies(xsrf_cookie)

    # ANSI escape sequences; stripped from kernel error tracebacks.
    __ansi_reg_exp = re.compile(r"(\x9B|\x1B\[)[0-?]*[ -\/]*[@-~]")

    @classmethod
    def _ansi_escape(cls, line: str) -> str:
        """Return *line* with ANSI escape sequences removed."""
        return cls.__ansi_reg_exp.sub("", line)

    def _user_url(self, path: str) -> str:
        """Return the URL of *path* under this user's ``user/<name>/`` prefix."""
        return self.jupyter_url + f"user/{self.user.username}/{path}"

    async def hub_login(self) -> None:
        """GET the hub login page; raise unless it answers 200."""
        login_url = self.jupyter_url + "hub/login"
        async with self.session.get(login_url) as r:
            if r.status != 200:
                await self._raise_error("Error logging into hub", r)

    async def ensure_lab(self) -> None:
        """Log into the lab if it is running, otherwise spawn it."""
        self.log.info("Ensure lab")
        if await self.is_lab_running():
            await self.lab_login()
        else:
            await self.spawn_lab()

    async def lab_login(self) -> None:
        """GET the lab page; raise unless it answers 200."""
        self.log.info("Logging into lab")
        async with self.session.get(self._user_url("lab")) as r:
            if r.status != 200:
                await self._raise_error("Error logging into lab", r)

    async def is_lab_running(self) -> bool:
        """Return False when the hub redirects to the spawn page, else True."""
        self.log.info("Is lab running?")
        hub_url = self.jupyter_url + "hub"
        spawn_url = self.jupyter_url + "hub/spawn"
        async with self.session.get(hub_url) as r:
            if r.status != 200:
                self.log.error(f"Error {r.status} from {r.url}")
            self.log.info(f"Going to {hub_url} redirected to {r.url}")
            landed_on = str(r.url)
        return landed_on != spawn_url

    async def spawn_lab(self) -> None:
        """Request a lab spawn and poll every 15 s (up to 900 s) until it is up."""
        username = self.user.username
        spawn_url = self.jupyter_url + "hub/spawn"
        pending_url = self.jupyter_url + f"hub/spawn-pending/{username}"
        lab_url = self._user_url("lab")

        # DM-23864: Do a get on the spawn URL even if I don't have to.
        async with self.session.get(spawn_url) as r:
            await r.text()

        async with self.session.post(spawn_url,
                                     data=self.jupyter_options_form,
                                     allow_redirects=False) as r:
            if r.status != 302:
                await self._raise_error("Spawn did not redirect", r)
            expected_location = self.jupyter_base + f"hub/spawn-pending/{username}"
            if r.headers["Location"] != expected_location:
                await self._raise_error("Spawn didn't redirect to pending", r)

        # Jupyterlab will give up a spawn after 900 seconds, so we shouldn't
        # wait longer than that.
        max_poll_secs = 900
        poll_interval = 15
        for _ in range(max_poll_secs // poll_interval):
            async with self.session.get(pending_url) as r:
                if str(r.url) == lab_url:
                    self.log.info(f"Lab spawned, redirected to {r.url}")
                    return
                if not r.ok:
                    await self._raise_error("Error spawning", r)
                self.log.info(f"Still waiting for lab to spawn {r}")
            await asyncio.sleep(poll_interval)

        raise Exception("Giving up waiting for lab to spawn!")

    async def delete_lab(self) -> None:
        """DELETE the user's server; raise unless it answers 200, 202 or 204."""
        username = self.user.username
        headers = {"Referer": self.jupyter_url + "hub/home"}
        server_url = self.jupyter_url + f"hub/api/users/{username}/server"
        self.log.info(f"Deleting lab for {username} at {server_url}")
        async with self.session.delete(server_url, headers=headers) as r:
            if r.status not in (200, 202, 204):
                await self._raise_error("Error deleting lab", r)

    async def create_kernel(self, kernel_name: str = "LSST") -> str:
        """Start a kernel named *kernel_name* and return its id."""
        body = {"name": kernel_name}
        async with self.session.post(self._user_url("api/kernels"),
                                     json=body) as r:
            if r.status != 201:
                await self._raise_error("Error creating kernel", r)
            created = await r.json()
            return created["id"]

    async def run_python(self, kernel_id: str, code: str) -> str:
        """Execute *code* on kernel *kernel_id* and return its output.

        Returns the text of the first ``stream`` message answering this
        request, or ``""`` on an ``ok`` ``execute_reply``.

        Raises:
            NotebookException: on an ``error`` message or a non-``ok`` reply.
        """
        kernel_url = self._user_url(f"api/kernels/{kernel_id}/channels")
        msg_id = uuid4().hex
        msg = {
            "header": {
                "username": "",
                "version": "5.0",
                "session": "",
                "msg_id": msg_id,
                "msg_type": "execute_request",
            },
            "parent_header": {},
            "channel": "shell",
            "content": {
                "code": code,
                "silent": False,
                "store_history": False,
                "user_expressions": {},
                "allow_stdin": False,
            },
            "metadata": {},
            "buffers": {},
        }

        async with self.session.ws_connect(kernel_url) as ws:
            await ws.send_json(msg)
            while True:
                reply = await ws.receive_json()
                self.log.debug(f"Recieved kernel message: {reply}")
                msg_type = reply["msg_type"]
                if msg_type == "error":
                    error_text = "".join(reply["content"]["traceback"])
                    raise NotebookException(self._ansi_escape(error_text))
                if (msg_type == "stream"
                        and msg_id == reply["parent_header"]["msg_id"]):
                    return reply["content"]["text"]
                if msg_type == "execute_reply":
                    status = reply["content"]["status"]
                    if status != "ok":
                        raise NotebookException(
                            f"Error content status is {status}")
                    return ""

    def dump(self) -> dict:
        """Return the session cookies as strings, for diagnostics."""
        return {
            "cookies": list(map(str, self.session.cookie_jar)),
        }

    async def _raise_error(self, msg: str, r: ClientResponse) -> None:
        """Always raise an ``Exception`` describing *msg* and response *r*."""
        raise Exception(f"{msg}: {r.status} {r.url}: {r.headers}")
class WebsocketClient:
    """WebSocket client driven synchronously through a context manager.

    Connection settings are read from ``config`` when the client is built.
    Entering the context opens an aiohttp session and WebSocket by driving
    ``loop`` with ``run_until_complete``; leaving it closes both. Signals are
    emitted around start/stop and around every action sent or received so
    other components can hook in.
    """

    def __init__(self, loop):
        self.url = config.get('server.publicUrl')
        self.client_id = config.get('client.id')
        self.client_token = config.get('client.token')
        self.reports_frequency = config.get('client.reports.frequency')
        self.loop = loop
        self.session = None
        self.socket = None

    def __enter__(self):
        """Open the session and WebSocket; clean up and re-raise on failure."""
        try:
            before_client_start.emit(client=self)
            headers = self._headers()
            self.session = ClientSession(loop=self.loop)
            self.socket = self.loop.run_until_complete(
                self.session.ws_connect(self.url, headers=headers))
            after_client_start.emit(client=self)
            return self
        except BaseException:
            # Release whatever was opened before the failure, then propagate.
            self._close()
            raise

    def __exit__(self, exc_type, exc_val, exc_tb):
        before_client_stop.emit(client=self)
        self._close()
        after_client_stop.emit(client=self)

    def _close(self):
        """Close the WebSocket (GOING_AWAY) and then the session, if opened."""
        try:
            if self.socket:
                self.loop.run_until_complete(
                    self.socket.close(code=WSCloseCode.GOING_AWAY))
        finally:
            # Release the HTTP session even if closing the socket raised.
            if self.session:
                self.loop.run_until_complete(self.session.close())

    def _headers(self):
        """Return the client id/token headers sent with the handshake."""
        return {
            CLIENT_ID_HEADER: str(self.client_id),
            CLIENT_TOKEN_HEADER: str(self.client_token),
        }

    def _send_json(self, payload):
        """Send *payload* as JSON over the socket from synchronous code.

        ``send_json`` on an aiohttp 3.x WebSocket is a coroutine function:
        calling it without awaiting sends nothing. When the call returns a
        coroutine it is scheduled as a task on ``self.loop``, and any failure
        is reported by ``_report_send_failure``. A non-coroutine result (a
        synchronous ``send_json``) needs no further handling.
        """
        result = self.socket.send_json(payload)
        if asyncio.iscoroutine(result):
            task = self.loop.create_task(result)
            task.add_done_callback(self._report_send_failure)

    @staticmethod
    def _report_send_failure(task):
        """Done-callback: print the error of a scheduled send that failed."""
        if task.cancelled():
            return
        error = task.exception()
        if error is not None:
            print('Sending message failed: {}'.format(error), flush=True)

    def _process_confirm(self, action):
        """Handle the server's confirmation of an action this client sent."""
        try:
            print('Received confirmation of action "{}" with payload: {}'.format(action.name, action.payload), flush=True)
            action.after_confirm.emit(client=self, action=action)
            after_server_action_confirm.emit(client=self, action=action)
        except Exception as e:
            print('Processing confirmation of action "{}" failed: {}'.format(action.name, e), flush=True)

    def _process_action(self, action):
        """Execute an action sent by the server and send back its confirmation."""
        try:
            print('Received action "{}" with payload: {}'.format(action.name, action.payload), flush=True)
            before_client_action_receive.emit(client=self, action=action)
            action.before_receive.emit(client=self, action=action)
            action.execute(self)
            action.after_receive.emit(client=self, action=action)
            after_client_action_receive.emit(client=self, action=action)
            self._send_json(action.send_confirm())
        except Exception as e:
            print('Executing action "{}" failed: {}'.format(action.name, e), flush=True)

    def _process_message(self, msg):
        """Parse a TEXT frame into an action and route it; ignore other frames."""
        if msg.type != WSMsgType.TEXT:
            return
        try:
            action = parse_client_action(msg.data)
        except Exception as e:
            print('Invalid message received: {}; Error: {}'.format(msg.data, e), flush=True)
            return
        if action.confirm:
            self._process_confirm(action)
        else:
            self._process_action(action)

    async def connect(self):
        """Send the label and first report, then process incoming messages."""
        self.send_label()
        self.send_reports()
        async for msg in self.socket:
            self._process_message(msg)

    def send_action(self, action):
        """Send *action* to the server, emitting the before/after send signals."""
        before_server_action_send.emit(client=self, action=action)
        action.before_send.emit(client=self, action=action)
        self._send_json(action.send())
        action.after_send.emit(client=self, action=action)
        after_server_action_send.emit(client=self, action=action)

    async def check_send_action(self, action):
        """Send *action* and wait for the server to confirm it.

        Waits at most ``ACTION_CONFIRM_TIMEOUT`` seconds; ``asyncio.wait_for``
        raises ``asyncio.TimeoutError`` if no matching confirmation arrives.
        """
        future = asyncio.Future()
        sent_action = action

        def listener(client, action):
            # Only a confirmation of this exact action from this client counts.
            if client is not self:
                return
            if action.name != sent_action.name:
                return
            if action.action_id != sent_action.action_id:
                return
            if not future.done():
                future.set_result(True)

        with after_server_action_confirm.connected(listener):
            self.send_action(sent_action)
            await asyncio.wait_for(future, ACTION_CONFIRM_TIMEOUT)

    def send_label(self):
        """Send the configured client label to the server."""
        label = config.get('client.label')
        self.send_action(SaveLabelServerAction(label=label))

    def send_reports(self):
        """Send a report now and reschedule every ``reports_frequency`` seconds."""
        report = collect_report()
        self.send_action(SaveReportServerAction(report=report))
        self.loop.call_later(self.reports_frequency, self.send_reports)
class WsConn(Conn):
    """``Conn`` implementation that talks over an aiohttp WebSocket.

    Accepted URLs look like ``ws://hostname:port/...`` or
    ``wss://hostname:port/...``. A caller-supplied ``ClientSession`` is shared
    and left open by :meth:`clean`; otherwise a private session is created
    here and closed by :meth:`clean`.

    Client errors, timeouts and task cancellation during reads and sends are
    reported by returning ``None`` / ``False`` rather than being raised.
    """

    __slots__ = ('_is_sharing_session', '_session', '_ws_receive_timeout',
                 '_ws_heartbeat', '_ws')

    def __init__(
            self,
            url: str,
            receive_timeout: Optional[float] = None,
            session: Optional[ClientSession] = None,
            # Forwarded to ws_connect(receive_timeout=...); used for the
            # automatic ping/pong.
            ws_receive_timeout: Optional[float] = None,
            # Forwarded to ws_connect(heartbeat=...); used for the automatic
            # ping/pong.
            ws_heartbeat: Optional[float] = None):
        super().__init__(url, receive_timeout)

        scheme = urlparse(url).scheme
        if scheme not in ('ws', 'wss'):
            raise TypeError(f'url scheme must be websocket ({scheme})')
        self._url = url

        self._is_sharing_session = session is not None
        self._session = ClientSession() if session is None else session

        self._ws_receive_timeout = ws_receive_timeout
        self._ws_heartbeat = ws_heartbeat
        self._ws = None

    async def open(self) -> bool:
        """Connect the WebSocket; return False on client error or a 3 s timeout."""
        try:
            self._ws = await asyncio.wait_for(
                self._session.ws_connect(
                    self._url,
                    receive_timeout=self._ws_receive_timeout,
                    heartbeat=self._ws_heartbeat,
                ),
                timeout=3,
            )
        except (ClientError, asyncio.TimeoutError):
            return False
        return True

    async def close(self) -> bool:
        """Close the WebSocket if one was opened; always returns True."""
        ws = self._ws
        if ws is not None:
            await ws.close()
        return True

    async def clean(self) -> None:
        """Close the ClientSession unless it was supplied by the caller."""
        if self._is_sharing_session:
            return
        await self._session.close()

    async def send_bytes(self, bytes_data: bytes) -> bool:
        """Send a binary frame; return False on client error or cancellation."""
        try:
            await self._ws.send_bytes(bytes_data)
        except (ClientError, asyncio.CancelledError):
            return False
        return True

    async def _receive_ws_message(self):
        """Wait up to ``_receive_timeout`` for the next WebSocket message.

        Returns None on client error, timeout or cancellation.
        """
        try:
            return await asyncio.wait_for(self._ws.receive(),
                                          timeout=self._receive_timeout)
        except (ClientError, asyncio.TimeoutError, asyncio.CancelledError):
            return None

    async def read_bytes(self) -> Optional[bytes]:
        """Return the payload of the next BINARY frame, or None otherwise."""
        msg = await self._receive_ws_message()
        if msg is not None and msg.type == WSMsgType.BINARY:
            return msg.data
        return None

    async def read_json(self) -> Any:
        """Decode the next TEXT or (UTF-8) BINARY frame as JSON.

        Returns None when nothing usable was received.
        """
        msg = await self._receive_ws_message()
        if msg is None:
            return None
        if msg.type == WSMsgType.TEXT:
            return json.loads(msg.data)
        if msg.type == WSMsgType.BINARY:
            return json.loads(msg.data.decode('utf8'))
        return None

    async def read_exactly_bytes(self, n: int) -> Optional[bytes]:
        """Not supported: WebSocket messages are already framed."""
        raise NotImplementedError(
            "Sorry, but I don't think we need this in WebSocket.")

    async def read_exactly_json(self, n: int) -> Any:
        """Not supported: WebSocket messages are already framed."""
        raise NotImplementedError(
            "Sorry, but I don't think we need this in WebSocket.")
async def _run_forever(
    self,
    url: StrOrURL,
    session: aiohttp.ClientSession,
    *,
    send_str: Optional[Union[str, list[str]]] = None,
    send_bytes: Optional[Union[bytes, list[bytes]]] = None,
    send_json: Any = None,
    hdlr_str=None,
    hdlr_bytes=None,
    hdlr_json=None,
    auth=_Auth,
    **kwargs: Any,
) -> None:
    """Keep a WebSocket connection to *url* alive until *session* is closed.

    Each attempt opens ``session.ws_connect(url, auth=auth, **kwargs)``.
    It then marks the runner as connected (``self.conneted`` / ``self._event``)
    and sends the optional initial ``send_str`` / ``send_bytes`` /
    ``send_json`` payloads; a list sends every item concurrently. Incoming
    frames are then dispatched:

    * TEXT frames go to ``hdlr_str(data, ws)`` and, when the text parses as
      JSON, to ``hdlr_json(data, ws)``;
    * BINARY frames go to ``hdlr_bytes(data, ws)``;
    * an ERROR frame ends the connection.

    Handlers may be plain callables or coroutine functions. Exceptions they
    raise are logged and do not break the connection. When neither
    ``hdlr_str`` nor ``hdlr_json`` is given, ``pybotters.print_handler`` is
    used as the JSON handler.

    Attempts start at most once every 60 seconds. Handshake and
    connection-reset errors are logged, and the loop waits out the rest of
    that cooldown before reconnecting. Any other exception (or cancellation)
    propagates after the connection state is reset and the pending cooldown
    is cancelled.
    """
    if hdlr_str is None and hdlr_json is None:
        hdlr_json = pybotters.print_handler

    iscorofunc_str = asyncio.iscoroutinefunction(hdlr_str)
    iscorofunc_bytes = asyncio.iscoroutinefunction(hdlr_bytes)
    iscorofunc_json = asyncio.iscoroutinefunction(hdlr_json)

    async def send_initial(send, payload) -> None:
        # A list means "send every item"; the sends run concurrently.
        if payload is None:
            return
        if isinstance(payload, list):
            await asyncio.gather(*[send(item) for item in payload])
        else:
            await send(payload)

    async def dispatch(handler, iscorofunc, data, ws) -> None:
        # Run one user handler; log (but swallow) anything it raises so a
        # faulty handler cannot tear down the connection.
        try:
            if iscorofunc:
                await handler(data, ws)
            else:
                handler(data, ws)
        except Exception as e:
            logger.exception(f"{pretty_modulename(e)}: {e}")

    while not session.closed:
        cooldown = asyncio.create_task(asyncio.sleep(60.0))
        try:
            try:
                async with session.ws_connect(url, auth=auth, **kwargs) as ws:
                    self.conneted = True
                    self._event.set()
                    await send_initial(ws.send_str, send_str)
                    await send_initial(ws.send_bytes, send_bytes)
                    await send_initial(ws.send_json, send_json)
                    async for msg in ws:
                        if msg.type == aiohttp.WSMsgType.TEXT:
                            if hdlr_str is not None:
                                await dispatch(
                                    hdlr_str, iscorofunc_str, msg.data, ws)
                            if hdlr_json is not None:
                                try:
                                    data = msg.json()
                                except json.decoder.JSONDecodeError:
                                    pass
                                else:
                                    await dispatch(
                                        hdlr_json, iscorofunc_json, data, ws)
                        elif msg.type == aiohttp.WSMsgType.BINARY:
                            if hdlr_bytes is not None:
                                await dispatch(
                                    hdlr_bytes, iscorofunc_bytes, msg.data, ws)
                        elif msg.type == aiohttp.WSMsgType.ERROR:
                            break
            except (
                aiohttp.WSServerHandshakeError,
                aiohttp.ClientOSError,
                ConnectionResetError,
            ) as e:
                logger.warning(f"{pretty_modulename(e)}: {e}")
            finally:
                # The connection is gone however the block above was left.
                self.conneted = False
                self._event.clear()
            await cooldown
        finally:
            # No-op once the cooldown has finished; if an uncaught error or a
            # cancellation escaped, don't leave the sleep task pending.
            cooldown.cancel()
class TestAdminServer(AsyncTestCase):
    """Tests for AdminServer start-up checks, responders and HTTP/WS routes."""

    async def setUp(self):
        self.message_results = []
        self.webhook_results = []
        self.port = 0

        # Connection pool limited to 16 connections, 4 per host.
        self.connector = TCPConnector(limit=16, limit_per_host=4)
        self.client_session = ClientSession(cookie_jar=DummyCookieJar(),
                                            connector=self.connector)

    async def tearDown(self):
        if self.client_session:
            await self.client_session.close()
            self.client_session = None

    async def test_debug_middleware(self):
        with async_mock.patch.object(test_module, "LOGGER",
                                     async_mock.MagicMock()) as mock_logger:
            mock_logger.isEnabledFor = async_mock.MagicMock(return_value=True)
            mock_logger.debug = async_mock.MagicMock()

            request = async_mock.MagicMock(
                method="GET",
                path_qs="/hello/world?a=1&b=2",
                match_info={"match": "info"},
                text=async_mock.CoroutineMock(return_value="abc123"),
            )
            handler = async_mock.CoroutineMock()

            await test_module.debug_middleware(request, handler)
            mock_logger.isEnabledFor.assert_called_once()
            assert mock_logger.debug.call_count == 3

    def get_admin_server(self,
                         settings: dict = None,
                         context: InjectionContext = None) -> AdminServer:
        """Build an AdminServer on a fresh unused port with mocked collaborators.

        A ``task_queue`` entry in *settings* is popped from the dict (after
        the settings are applied to the context). When it is truthy the server
        gets a ``TaskQueue``; otherwise it gets mocked ``conductor_stats``.
        """
        if not context:
            context = InjectionContext()
        if settings:
            context.update_settings(settings)

        # middleware is task queue xor collector: cover both over test suite
        task_queue = (settings or {}).pop("task_queue", None)

        plugin_registry = async_mock.MagicMock(test_module.PluginRegistry,
                                               autospec=True)
        plugin_registry.post_process_routes = async_mock.MagicMock()
        context.injector.bind_instance(test_module.PluginRegistry,
                                       plugin_registry)

        collector = Collector()
        context.injector.bind_instance(test_module.Collector, collector)

        self.port = unused_port()
        return AdminServer(
            "0.0.0.0",
            self.port,
            context,
            self.outbound_message_router,
            self.webhook_router,
            conductor_stop=async_mock.CoroutineMock(),
            task_queue=TaskQueue(max_active=4) if task_queue else None,
            conductor_stats=(None if task_queue else async_mock.CoroutineMock(
                return_value=[1, 2])),
        )

    async def outbound_message_router(self, *args):
        """Record outbound messages routed by the server under test."""
        self.message_results.append(args)

    def webhook_router(self, *args):
        """Record webhooks routed by the server under test."""
        self.webhook_results.append(args)

    async def test_start_stop(self):
        with self.assertRaises(AssertionError):
            await self.get_admin_server().start()

        settings = {"admin.admin_insecure_mode": False}
        with self.assertRaises(AssertionError):
            await self.get_admin_server(settings).start()

        settings = {
            "admin.admin_insecure_mode": True,
            "admin.admin_api_key": "test-api-key",
        }
        with self.assertRaises(AssertionError):
            await self.get_admin_server(settings).start()

        settings = {
            "admin.admin_insecure_mode": False,
            "admin.admin_api_key": "test-api-key",
        }
        server = self.get_admin_server(settings)
        await server.start()
        with async_mock.patch.object(server, "websocket_queues",
                                     async_mock.MagicMock()) as mock_wsq:
            mock_wsq.values = async_mock.MagicMock(return_value=[
                async_mock.MagicMock(stop=async_mock.MagicMock())
            ])
            await server.stop()

        with async_mock.patch.object(web.TCPSite, "start",
                                     async_mock.CoroutineMock()) as mock_start:
            mock_start.side_effect = OSError("Failure to launch")
            with self.assertRaises(AdminSetupError):
                await self.get_admin_server(settings).start()

    async def test_responder_send(self):
        message = OutboundMessage(payload="{}")
        server = self.get_admin_server()
        await server.responder.send_outbound(message)
        assert self.message_results == [(server.context, message)]

    async def test_responder_webhook(self):
        server = self.get_admin_server()
        test_url = "target_url"
        test_attempts = 99
        server.add_webhook_target(
            target_url=test_url,
            topic_filter=["*"],  # cover vacuous filter
            max_attempts=test_attempts,
        )
        test_topic = "test_topic"
        test_payload = {"test": "TEST"}

        with async_mock.patch.object(server, "websocket_queues",
                                     async_mock.MagicMock()) as mock_wsq:
            mock_wsq.values = async_mock.MagicMock(return_value=[
                async_mock.MagicMock(authenticated=True,
                                     enqueue=async_mock.CoroutineMock())
            ])
            await server.responder.send_webhook(test_topic, test_payload)
            assert self.webhook_results == [
                (test_topic, test_payload, test_url, test_attempts)
            ]

        server.remove_webhook_target(target_url=test_url)
        assert test_url not in server.webhook_targets

    async def test_import_routes(self):
        # this test just imports all default admin routes
        # for routes with associated tests, this shouldn't make a difference in coverage
        context = InjectionContext()
        context.injector.bind_instance(ProtocolRegistry, ProtocolRegistry())
        await DefaultContextBuilder().load_plugins(context)
        server = self.get_admin_server({"admin.admin_insecure_mode": True},
                                       context)
        await server.make_application()

    async def test_register_external_plugin_x(self):
        context = InjectionContext()
        context.injector.bind_instance(ProtocolRegistry, ProtocolRegistry())
        with self.assertRaises(ValueError):
            builder = DefaultContextBuilder(
                settings={"external_plugins": "aries_cloudagent.nosuchmodule"})
            await builder.load_plugins(context)

    async def test_visit_insecure_mode(self):
        settings = {"admin.admin_insecure_mode": True, "task_queue": True}
        server = self.get_admin_server(settings)
        await server.start()
        # Stop the server even if an assertion fails, so it does not outlive
        # the test.
        try:
            async with self.client_session.post(
                    f"http://127.0.0.1:{self.port}/status/reset",
                    headers={}) as response:
                assert response.status == 200

            async with self.client_session.ws_connect(
                    f"http://127.0.0.1:{self.port}/ws") as ws:
                result = await ws.receive_json()
                assert result["topic"] == "settings"

            for path in (
                    "",
                    "plugins",
                    "status",
                    "status/live",
                    "status/ready",
                    "shutdown",  # mock conductor has magic-mock stop()
            ):
                async with self.client_session.get(
                        f"http://127.0.0.1:{self.port}/{path}",
                        headers={}) as response:
                    assert response.status == 200
        finally:
            await server.stop()

    async def test_visit_secure_mode(self):
        settings = {
            "admin.admin_insecure_mode": False,
            "admin.admin_api_key": "test-api-key",
        }
        server = self.get_admin_server(settings)
        await server.start()
        try:
            async with self.client_session.get(
                    f"http://127.0.0.1:{self.port}/status",
                    headers={"x-api-key": "wrong-key"}) as response:
                assert response.status == 401

            async with self.client_session.get(
                    f"http://127.0.0.1:{self.port}/status",
                    headers={"x-api-key": "test-api-key"},
            ) as response:
                assert response.status == 200

            async with self.client_session.ws_connect(
                    f"http://127.0.0.1:{self.port}/ws",
                    headers={"x-api-key": "test-api-key"}) as ws:
                result = await ws.receive_json()
                assert result["topic"] == "settings"
        finally:
            await server.stop()

    async def test_visit_shutting_down(self):
        settings = {
            "admin.admin_insecure_mode": True,
        }
        server = self.get_admin_server(settings)
        await server.start()
        try:
            async with self.client_session.get(
                    f"http://127.0.0.1:{self.port}/shutdown",
                    headers={}) as response:
                assert response.status == 200

            async with self.client_session.get(
                    f"http://127.0.0.1:{self.port}/status",
                    headers={}) as response:
                assert response.status == 503

            async with self.client_session.get(
                    f"http://127.0.0.1:{self.port}/status/live",
                    headers={}) as response:
                assert response.status == 200
        finally:
            await server.stop()
def _is_websocket_upgrade(req: web.Request) -> bool:
    """Return True if *req* is a GET asking to upgrade to a WebSocket."""
    # "Connection" is a comma-separated token list (some browsers send
    # "keep-alive, Upgrade"), and both header values are case-insensitive.
    # Missing headers must not raise.
    connection_tokens = {
        token.strip().lower()
        for token in req.headers.get("Connection", "").split(",")
    }
    return (
        req.method == "GET"
        and "upgrade" in connection_tokens
        and req.headers.get("Upgrade", "").lower() == "websocket"
    )


async def _ws_forward(ws_from, ws_to) -> None:
    """Relay frames from *ws_from* to *ws_to* until *ws_from* stops.

    Iteration over an aiohttp WebSocket ends by itself on CLOSE/CLOSED
    frames. An ERROR frame also ends the relay.
    """
    async for msg in ws_from:
        print(f">>> msg: {msg}")
        mt = msg.type
        md = msg.data
        if mt == WSMsgType.TEXT:
            await ws_to.send_str(md)
        elif mt == WSMsgType.BINARY:
            await ws_to.send_bytes(md)
        elif mt == WSMsgType.PING:
            await ws_to.ping()
        elif mt == WSMsgType.PONG:
            await ws_to.pong()
        elif mt == WSMsgType.ERROR:
            break
        else:
            raise ValueError(f"unexpected message type: {msg}")


async def proxy_handler(req: web.Request) -> web.Response:
    """Proxy *req* to the logged-in user's code-server container.

    Requests whose session has no ``container_name`` are redirected to
    ``/login``. Otherwise the container is found or created and the request is
    forwarded to ``http://<container_name>:8080``:

    * WebSocket upgrade requests are bridged frame by frame, with the
      ``/devenv`` prefix stripped from the path.
    * All other requests are replayed over HTTP and the upstream response is
      returned, with ``service-worker-allowed: /`` added.
    """
    sess = await get_session(req)
    if "container_name" not in sess:
        raise web.HTTPFound("/login")
    container_name = sess["container_name"]

    code_server_manager = CodeServerManager(container_name)
    await code_server_manager.find_or_create_container()
    base_url = f"http://{container_name}:8080"

    if _is_websocket_upgrade(req):
        ws_server = web.WebSocketResponse()
        await ws_server.prepare(req)
        print(f"##### WS_SERVER {ws_server}")
        path_qs_cleaned = req.path_qs.removeprefix("/devenv")
        # One upstream session per bridged connection; `async with` makes
        # sure it is closed when the bridge ends.
        async with ClientSession(cookies=req.cookies) as client_session:
            async with client_session.ws_connect(
                    base_url + path_qs_cleaned) as ws_client:
                print(f"##### WS_CLIENT {ws_client}")
                # asyncio.wait needs Tasks (bare coroutines are rejected
                # since Python 3.11).
                forwarders = [
                    asyncio.create_task(_ws_forward(ws_server, ws_client)),
                    asyncio.create_task(_ws_forward(ws_client, ws_server)),
                ]
                try:
                    # Bridge until either direction stops.
                    await asyncio.wait(forwarders,
                                       return_when=asyncio.FIRST_COMPLETED)
                finally:
                    # Stop the other direction and collect results so no
                    # task is left running or reported as never retrieved.
                    for task in forwarders:
                        task.cancel()
                    await asyncio.gather(*forwarders, return_exceptions=True)
        # Send a proper close frame to the browser (no-op if already closed).
        await ws_server.close()
        return ws_server

    # Plain HTTP: strip the "/devenv" mount prefix and replay the request.
    proxy_path = req.path_qs
    if proxy_path != "":
        proxy_path = (
            proxy_path.removeprefix("/devenv")
            .removeprefix("devenv")
            .removeprefix("/")
        )
        proxy_path = "/" + proxy_path
    # NOTE(review): `client` is not defined in this function; presumably a
    # module-level ClientSession — confirm.
    async with client.request(
        req.method,
        base_url + proxy_path,
        allow_redirects=False,
        data=await req.read(),
    ) as res:
        headers = res.headers.copy()
        headers["service-worker-allowed"] = "/"
        body = await res.read()
        return web.Response(headers=headers, status=res.status, body=body)
class LimooDriver:
    """Async client for a Limoo server's HTTP API and WebSocket event stream.

    Authenticated API calls go through ``_with_auth``, which logs in with the
    bot credentials and retries once when the server answers 401. Setting an
    event handler starts a background task that keeps a WebSocket connected
    and passes every received event to the handler.
    """

    # Together these bound and pace the reconnect schedule in ``_listen``:
    # the n-th failed attempt is followed by a sleep of n * _RETRY_DELAY s.
    _ALLOWED_CONNECTION_ATTEMPTS = 1000000
    _RETRY_DELAY = 2

    @staticmethod
    async def _receive_event(ws):
        """Return the next JSON message from *ws*, skipping undecodable ones."""
        while True:
            try:
                return await ws.receive_json()
            except ValueError:
                continue

    @staticmethod
    async def _get_text_body(response):
        """Read *response*'s body as text, always releasing the connection.

        Raises LimooError if the connection fails while reading.
        """
        try:
            return await response.text()
        except (ClientConnectionError, ClientPayloadError) as ex:
            raise LimooError from ex
        finally:
            await response.release()

    @staticmethod
    async def _get_json_body(response):
        """Read *response*'s body and decode it as JSON.

        Raises LimooError if reading fails or the body is not valid JSON; the
        error message includes the offending body.
        """
        response_text = await LimooDriver._get_text_body(response)
        try:
            return json.loads(response_text)
        except json.JSONDecodeError as ex:
            raise LimooError(
                f'Response body is not valid json: {response_text}') from ex

    def _with_auth(coro):
        """Decorate a coroutine method to log in and retry on auth failure.

        When the wrapped call raises LimooAuthenticationError, the driver logs
        in and retries. The login is skipped if another task already logged in
        since this call started, which is detected through
        ``_successful_login_count`` under ``_authlock``. A further
        authentication failure after that is re-raised.
        """
        @functools.wraps(coro)
        async def wrapper(self, *args, **kwargs):
            async with self._authlock:
                authenticated = False
                previous_slc = self._successful_login_count
            # asyncio.Lock is not reentrant, so the retry loop must run
            # outside the lock taken above.
            while True:
                try:
                    return await coro(self, *args, **kwargs)
                except LimooAuthenticationError:
                    if authenticated:
                        raise
                    async with self._authlock:
                        if self._successful_login_count == previous_slc:
                            await self._login()
                            self._successful_login_count += 1
                        authenticated = True
                        previous_slc = self._successful_login_count
        return wrapper

    def __init__(self, limoo_url, bot_username, bot_password, secure=True):
        # Catch a relatively common mistake and report an informative error
        assert not limoo_url.startswith(('http://', 'https://')), (
            'The URL of the Limoo server should not start with'
            f' "http://" or "https://". The received URL was "{limoo_url}"')
        self._credentials = {
            'j_username': bot_username,
            'j_password': bot_password,
        }
        if limoo_url.endswith('/'):
            limoo_url = limoo_url[:-1]
        http_url = f'http{"s" if secure else ""}://{limoo_url}'
        ws_url = f'ws{"s" if secure else ""}://{limoo_url}'
        self._login_url = f'{http_url}/Limonad/j_spring_security_check'
        self._api_url_prefix = f'{http_url}/Limonad/api/v1'
        self._fileop_url = f'{http_url}/fileserver/api/v1/files'
        self._websocket_url = f'{ws_url}/Limonad/websocket'
        self._client_session = ClientSession(cookie_jar=CookieJar(unsafe=True))
        self._successful_login_count = 0
        self._authlock = asyncio.Lock()
        self._listen_task = None
        self._event_handler = lambda event: None
        self.conversations = Conversations(self)
        self.files = Files(self)
        self.messages = Messages(self)
        self.users = Users(self)
        self.workspaces = Workspaces(self)

    async def close(self):
        """Stop the listen task (if any) and close the HTTP session."""
        if self._listen_task:
            self._listen_task.cancel()
            try:
                await self._listen_task
            except asyncio.CancelledError:
                pass
        await self._client_session.close()

    async def _login(self):
        """POST the bot credentials to the login endpoint."""
        await self._execute_request('POST', self._login_url,
                                    data=self._credentials)

    @_with_auth
    async def _execute_api_get(self, endpoint):
        return await self._execute_json_request('GET', endpoint)

    @_with_auth
    async def _execute_api_post(self, endpoint, body):
        return await self._execute_json_request('POST', endpoint, body=body)

    async def _execute_json_request(self, method, endpoint, *, body=None):
        """Call an API endpoint with a JSON body and decode the JSON reply."""
        return await self._get_json_body(await self._execute_request(
            method, f'{self._api_url_prefix}/{endpoint}', json=body))

    @_with_auth
    async def _upload_file(self, path, name, mime_type):
        """Upload the file at *path* as a multipart field called *name*."""
        formdata = FormData(quote_fields=False)
        # Opening and closing the file run in the default executor.
        run_async = asyncio.get_running_loop().run_in_executor
        async with contextlib.AsyncExitStack() as stack:
            file = await run_async(None, open, path, 'rb')
            stack.push_async_callback(run_async, None, file.close)
            formdata.add_field(name, file, content_type=mime_type,
                               filename=name)
            return await self._get_json_body(await self._execute_request(
                'POST', self._fileop_url, data=formdata))

    @_with_auth
    async def _download_file(self, hash, name):
        """Start downloading a file; return a StreamReader over the response."""
        params = urllib.parse.urlencode({'hash': hash, 'name': name})
        return StreamReader(await self._execute_request(
            'GET', f'{self._fileop_url}?mode=download&{params}'))

    async def _execute_request(self, method, url, *, data=None, json=None):
        """Send a request (always with ``is_bot=true``).

        Returns the unread response for status < 400. Raises
        LimooAuthenticationError on 401 and LimooError on other error
        statuses or connection failures.
        """
        try:
            response = await self._client_session.request(
                method, url, data=data, json=json, params={"is_bot": "true"})
        except ClientConnectionError as ex:
            raise LimooError('Connection Error') from ex
        status = response.status
        if status < 400:
            return response
        response_text = await self._get_text_body(response)
        if status == 401:
            raise LimooAuthenticationError
        else:
            raise LimooError(
                f'Request returned unsuccessfully with status {status} and body {response_text}'
            )

    def set_event_handler(self, event_handler):
        """Set (or clear with None) the callable receiving WebSocket events.

        Starts the background listen task when a handler is set and cancels
        it when the handler is cleared.
        """
        if event_handler is not None and not callable(event_handler):
            raise ValueError(
                'event_handler must either be a callable or None.')
        self._event_handler = event_handler
        if self._event_handler and not self._listen_task:
            self._listen_task = asyncio.create_task(self._listen())
        elif not self._event_handler and self._listen_task:
            self._listen_task.cancel()
            self._listen_task = None

    async def _listen(self):
        """Keep a WebSocket connected and feed its events to the handler.

        Failed connection attempts are retried with a linearly growing sleep
        (2 s, 4 s, 6 s, ...) until the schedule is exhausted. A broken
        connection is closed and re-established. Cancellation is recorded,
        the socket is closed, and the CancelledError is re-raised at the end.
        """
        _LOGGER.info('WebSocket listen task started.')
        cancel_ex = None
        max_delay = self._ALLOWED_CONNECTION_ATTEMPTS * (self._RETRY_DELAY + 1)
        while not cancel_ex:
            ws = None
            for retry_delay in range(self._RETRY_DELAY,
                                     max_delay + self._RETRY_DELAY,
                                     self._RETRY_DELAY):
                try:
                    ws = await self._try_connecting()
                    break
                except asyncio.CancelledError as ex:
                    cancel_ex = ex
                    break
                except Exception as ex:
                    _LOGGER.error(
                        'Connecting the WebSocket failed with the following exception: %s',
                        ex)
                    if retry_delay == max_delay:
                        break
                    else:
                        _LOGGER.info(
                            'Going to sleep for %d seconds before trying to connect a WebSocket.',
                            retry_delay)
                        try:
                            await asyncio.sleep(retry_delay)
                        except asyncio.CancelledError as ex:
                            cancel_ex = ex
                            break
            if not ws:
                break

            _LOGGER.info('Connected the WebSocket.')
            while True:
                try:
                    event = await self._receive_event(ws)
                except asyncio.CancelledError as ex:
                    cancel_ex = ex
                    break
                except Exception as ex:
                    _LOGGER.error(
                        'The connected WebSocket broke with the following exception: %s',
                        ex)
                    break
                _LOGGER.debug('Received an event.')
                try:
                    self._event_handler(event)
                except Exception as ex:
                    _LOGGER.error(
                        'Calling the event handler failed with the following exception: %s',
                        ex)

            try:
                await ws.close()
                _LOGGER.info('The WebSocket was closed.')
            except asyncio.CancelledError as ex:
                cancel_ex = ex
            except Exception as ex:
                _LOGGER.error(
                    'Closing the WebSocket failed with the following exception: %s',
                    ex)
        _LOGGER.info('WebSocket listen task ended.')
        if cancel_ex:
            raise cancel_ex

    @_with_auth
    async def _try_connecting(self):
        """Open the WebSocket and check the server's first event.

        Raises LimooAuthenticationError (closing the socket) when the server
        reports ``authentication_failed``; otherwise returns the open socket.
        """
        async with contextlib.AsyncExitStack() as stack:
            ws = await stack.enter_async_context(
                self._client_session.ws_connect(self._websocket_url,
                                                receive_timeout=70,
                                                heartbeat=60))
            event = await self._receive_event(ws)
            if event.get('event') == 'authentication_failed':
                raise LimooAuthenticationError
            else:
                # Detach the socket from the stack so it stays open.
                stack.pop_all()
                return ws

    del _with_auth