Exemple #1
0
async def is_websocket_active(ws: WebSocket) -> bool:
    """Probe whether ``ws`` is still alive with an application-level ping/pong.

    A JSON ``{'type': 'ping'}`` message is sent, and a JSON reply whose
    ``type`` is ``'pong'`` is expected. The send and the receive must each
    finish within ``HEART_BEAT_INTERVAL`` seconds.

    Args:
        ws: The websocket to probe.

    Returns:
        True if both sides report CONNECTED and a pong arrives in time.
        False otherwise: not connected, timeout, send/receive failure, or a
        reply that is not a ``{'type': 'pong', ...}`` object.
    """
    if not (ws.application_state == WebSocketState.CONNECTED
            and ws.client_state == WebSocketState.CONNECTED):
        return False
    try:
        await asyncio.wait_for(ws.send_json({'type': 'ping'}), HEART_BEAT_INTERVAL)
        message = await asyncio.wait_for(ws.receive_json(), HEART_BEAT_INTERVAL)
    except Exception:
        # Timeouts, disconnects and decode errors all mean "not active".
        # Catch Exception rather than BaseException so that task
        # cancellation and KeyboardInterrupt still propagate instead of
        # being silently turned into False.
        return False
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    return isinstance(message, dict) and message.get('type') == 'pong'
Exemple #2
0
async def broadcast(socketList: "list[WebSocket]", message: str):
    """
    Broadcasts the same message across all sockets

    Sockets whose send fails are removed from ``socketList`` in place, so
    afterwards the caller's list only holds sockets that accepted the message.

    Arguments:
        socketList {List: [websockets]} -- list of all active websockets
        message {str} -- json message to send (sent as a text frame)
    """
    # Walk backwards so popping index i never shifts indices still to visit.
    for i in range(len(socketList) - 1, -1, -1):
        try:
            await socketList[i].send_text(message)
        except Exception:
            # A failed send means the socket is unusable, so drop it.
            # Catching Exception (not a bare except) lets task cancellation
            # propagate instead of being mistaken for a dead socket.
            socketList.pop(i)
Exemple #3
0
    def __init__(
        self,
        scope: Scope,
        receive: Receive,
        send: Send,
        value_type: Optional[str] = None,
        receive_type: Optional[str] = None,
        send_type: Optional[str] = None,
        caught_close_codes: Optional[Tuple[int, ...]] = None,
    ):
        """Wrap a Starlette websocket and record close-code / message-type settings.

        Args:
            scope, receive, send: The ASGI connection triple, passed straight
                to the wrapped ``StarletteWebSocket``.
            value_type: When given, used for both ``receive_type`` and
                ``send_type``, overriding them.
            receive_type: Stored as ``self.receive_type``. Falls back to
                ``__default_receive_type__`` when falsy.
            send_type: Stored as ``self.send_type``. Falls back to
                ``__default_send_type__`` when falsy.
            caught_close_codes: Stored as ``self.caught_close_codes`` and
                defaults to ``(1000, 1001)``. Passing the builtin ``all``
                selects every code in ``WEBSOCKET_CLOSE_CODES``. (By the
                name, presumably these are the close codes to catch rather
                than raise — confirm.)
        """
        # Composition rather than inheritance: we redefine `receive()` and
        # `send()`, but Starlette's WebSocket calls those from many of its
        # other methods, which we neither need nor want to re-implement.
        # Delegating methods elsewhere on this class bridge the gap.
        self._ws = StarletteWebSocket(scope, receive=receive, send=send)

        codes = (1000, 1001) if caught_close_codes is None else caught_close_codes
        if codes is all:
            # The builtin `all` acts as a sentinel for "every known close code".
            codes = tuple(WEBSOCKET_CLOSE_CODES)
        self.caught_close_codes = codes

        if value_type is None:
            self.receive_type = receive_type or self.__default_receive_type__
            self.send_type = send_type or self.__default_send_type__
        else:
            self.receive_type = self.send_type = value_type
Exemple #4
0
    async def on_connect(self, websocket: WebSocket):
        """Accept the socket, authenticate it by path token, then push initial state.

        The connection is accepted first. It is closed again if the ``token``
        path parameter is missing or matches no user. Otherwise the user's
        data is processed and saved, and the connection is registered. Two
        messages are then sent through ``connections.notify``: a ``sync``
        payload with the serialized user data, and an ``items`` payload.
        """
        await websocket.accept()

        path_params = self.scope['path_params']
        user = None
        if 'token' in path_params:
            user = await models.User.filter(
                token=path_params['token']
            ).prefetch_related('data', 'items').get_or_none()

        if not user:
            # Missing token or unknown user: reject the connection.
            await websocket.close()
            return

        websocket.user_id = user.id
        await user.data.processing()
        await user.data.save()
        payload = await models.UserDataPydanic.from_tortoise_orm(user.data)

        await connections.add(UserConnect(user, websocket))
        await connections.notify(user.id, {'type': 'sync', 'data': payload.dict()})
        await connections.notify(user.id, {'type': 'items', 'items': user.items})
Exemple #5
0
async def ws_with_auth(websocket):
    """Re-wrap ``websocket`` in a fresh ``WebSocket``, confirm auth, then close.

    Only the incoming object's ``scope``, ``receive`` and ``send`` are used.
    The new connection is accepted, sent the text ``'Authentication valid'``,
    and closed.
    """
    conn = WebSocket(
        scope=websocket.scope,
        receive=websocket.receive,
        send=websocket.send,
    )
    await conn.accept()
    await conn.send_text('Authentication valid')
    await conn.close()
Exemple #6
0
        async def awaitable(receive: Receive, send: Send) -> None:
            """ASGI callable: run the injected ``func`` on a new WebSocket session.

            ``scope``, ``injector`` and ``func`` come from the enclosing scope.
            Extra keyword arguments are read from ``scope["kwargs"]`` (empty
            when absent) and forwarded to the handler.
            """
            ws_session = WebSocket(scope, receive=receive, send=send)
            handler_kwargs = scope.get("kwargs", {})
            handler = await injector.inject(func)
            await handler(ws_session, **handler_kwargs)
Exemple #7
0
    async def _dispatch_request(self, req, **options):
        """Build the response for ``req`` by running its matching route.

        The response object depends on the route:

        - A websocket route gets a ``WebSocket`` built from ``options``.
        - Any other route gets a ``models.Response``.
        - No matching route gets a ``models.Response`` with a not-found
          default response applied.

        For a matched route, the ``before_requests`` hooks run first (without
        ``options``), then the route itself (with them). In every case,
        ``default_response``, session preparation and cookie preparation then
        run on the response.

        Returns:
            The response object (``WebSocket`` or ``models.Response``).
        """
        req.formats = self.formats

        matched = self.path_matches_route(req.url.path)
        route = self.routes.get(matched)

        if not route:
            resp = models.Response(req=req, formats=self.formats)
            self.default_response(req, resp, notfound=True)
        else:
            if route.uses_websocket:
                resp = WebSocket(**options)
            else:
                resp = models.Response(req=req, formats=self.formats)

            for hook in self.before_requests:
                await self._execute_route(route=hook, req=req, resp=resp)
            await self._execute_route(route=route, req=req, resp=resp, **options)

        # NOTE: on the not-found path this is a second default_response call
        # (the first passed notfound=True); kept as in the original.
        self.default_response(req, resp)
        self._prepare_session(resp)
        self._prepare_cookies(resp)
        return resp
Exemple #8
0
 async def app(scope: Scope, receive: Receive, send: Send) -> None:
     """Accept a websocket and run ``reader`` alongside ``writer``, then close.

     ``reader`` is started as a child of an anyio task group, while
     ``writer`` runs in the current task. The socket is closed only after
     the task group block exits.
     """
     ws = WebSocket(scope, receive=receive, send=send)
     await ws.accept()
     async with anyio.create_task_group() as tg:
         tg.start_soon(reader, ws)
         await writer(ws)
     await ws.close()
    async def asgi(self, receive: Receive, send: Send, scope: Scope) -> None:
        """Serve a single GraphQL operation over the ``graphql-ws`` subprotocol.

        Flow:

        1. Accept with subprotocol ``graphql-ws``.
        2. Read the client's connection-init message and reply
           ``connection_ack``.
        3. Read one operation message.
        4. Stream each execution result as a ``data`` message.
        5. Send ``complete`` and close.
        """
        assert scope["type"] == "websocket"

        websocket = WebSocket(scope, receive=receive, send=send)
        await websocket.accept(subprotocol="graphql-ws")

        # The client opens with connection_init. The ack is the server's
        # *reply* to it, so read the init message before acknowledging.
        # TODO: we should check that this is a proper connection init message
        await websocket.receive_json()
        await self._send_message(websocket, "connection_ack")

        message = await websocket.receive_json()

        id_ = message.get("id", "1")
        payload = message.get("payload", {})

        # `variables` and `operationName` are optional in a GraphQL request.
        # A client that omits them should not crash the handler with KeyError.
        results = await self.execute(
            payload["query"],
            payload.get("variables"),
            operation_name=payload.get("operationName"),
        )

        async for result in results:
            # TODO: send errors if any
            await self._send_message(websocket, "data", {"data": result.data}, id_)

        # NOTE(review): `complete` is sent without `id_`. graphql-ws clients
        # typically match it to an operation by id, so check
        # `_send_message`'s signature before adding it.
        await self._send_message(websocket, "complete")
        await websocket.close()
Exemple #10
0
def test_websocket_scope_interface():
    """
    A WebSocket built from a scope plus receive/send callables exposes that
    scope through the `Mapping` interface.
    """
    expected_scope = {"type": "websocket", "path": "/abc/", "headers": []}

    async def mock_receive():
        pass  # pragma: no cover

    async def mock_send(message):
        pass  # pragma: no cover

    websocket = WebSocket(dict(expected_scope), receive=mock_receive, send=mock_send)
    assert websocket["type"] == "websocket"
    assert dict(websocket) == expected_scope
    assert len(websocket) == 3
Exemple #11
0
 async def asgi(receive, send):
     nonlocal close_code
     websocket = WebSocket(scope, receive=receive, send=send)
     await websocket.accept()
     try:
         await websocket.receive_text()
     except WebSocketDisconnect as exc:
         close_code = exc.code
Exemple #12
0
    async def __call__(self, scope, receive, send):
        """ASGI entry point: run websocket before-request hooks, then the endpoint.

        Hooks are read from ``scope["before_requests"]["ws"]``. If either
        level is absent, no hooks run.
        """
        ws = WebSocket(scope, receive, send)

        # Default to a dict: the value is queried with `.get("ws", ...)`, and
        # the previous list default raised AttributeError when the key was
        # missing from the scope.
        before_requests = scope.get("before_requests", {})
        for before_request in before_requests.get("ws", []):
            await before_request(ws)

        await self.endpoint(ws)
Exemple #13
0
 async def app(scope: Scope, receive: Receive, send: Send) -> None:
     nonlocal close_code
     websocket = WebSocket(scope, receive=receive, send=send)
     await websocket.accept()
     try:
         await websocket.receive_text()
     except WebSocketDisconnect as exc:
         close_code = exc.code
Exemple #14
0
 async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
     """Accept a websocket, then run the reload watcher and disconnect waiter.

     Both ``_watch_reloads`` and ``_wait_client_disconnect`` receive the socket
     as ``ws`` and are handed to ``run_until_first_complete``, which by its
     name returns once the first of them finishes.
     """
     assert scope["type"] == "websocket"
     connection = WebSocket(scope, receive, send)
     await connection.accept()
     watcher = (self._watch_reloads, {"ws": connection})
     disconnect_waiter = (self._wait_client_disconnect, {"ws": connection})
     await run_until_first_complete(watcher, disconnect_waiter)
Exemple #15
0
 def __init__(self, scope: Scope, receive: Receive, send: Send):
     """Wrap an ASGI websocket connection. Not meant to be constructed manually."""
     # Underlying WebSocket; this wrapper's headers are taken from it.
     self._connection = WebSocket(scope, receive, send)
     self.headers = self._connection.headers
     self.state = addict.Dict()
     self.reraise = False
     self.closed = False
     self._queries = None
Exemple #16
0
async def websocket_endpoint(websocket: WebSocket):
    """Serve one websocket connection.

    Connections whose ``origin`` header is missing or not listed in
    ``Configuration.CORS_ORGINS`` are accepted, sent a denial text, and
    closed.

    A ``token`` cookie that ``Auth.get_token_info`` resolves attaches
    ``auth_info`` to the socket and registers it in ``socket_by_user``.

    Each received JSON message is routed by its ``type`` to ``handlers``.
    Known failure exceptions are reported to the client as
    ``{"type": "error", ...}`` messages; any other exception is reported
    and then re-raised. ``cleanup`` always runs when the loop ends.
    """
    if "origin" not in websocket.headers or websocket.headers["origin"] not in Configuration.CORS_ORGINS:
        await websocket.accept()
        await websocket.send_text("it's too late in the evening to come up with a with a punny denied message so this will do for now")
        await websocket.close()
        return
    if "token" in websocket.cookies:
        token = websocket.cookies["token"]
        info = await Auth.get_token_info(token)
        if info is not None:
            websocket.auth_info = info
            socket_by_user.setdefault(info.user.id, []).append(websocket)

    websocket.active_subscriptions = dict()
    # try/finally rather than `except Exception` + `else`: cleanup must also
    # run on BaseException (e.g. task cancellation when the client drops).
    # Otherwise the socket stays registered in socket_by_user.
    try:
        await websocket.accept()
        print("Websocket accepted")
        while websocket.application_state == WebSocketState.CONNECTED and websocket.client_state == WebSocketState.CONNECTED:
            try:
                data = await websocket.receive_json()
                if data["type"] not in handlers:
                    await websocket.send_json(dict(type="error", content="Unknown type!"))
                else:
                    await handlers[data["type"]](websocket, data.get("message", {}))
            except WebSocketDisconnect:
                break
            except FailedException:
                await websocket.send_json(dict(type="error", content="Seems the bot failed to process your query, please try again later"))
            except UnauthorizedException:
                await websocket.send_json(dict(type="error", content="Access denied!"))
            except NoReplyException:
                await websocket.send_json(dict(type="error", content="Unable to communicate with GearBot, please try again later"))
            except Exception:
                await websocket.send_json(dict(type="error", content="Something went wrong!"))
                # Bare raise keeps the original traceback intact.
                raise
    finally:
        await cleanup(websocket)
    print("Websocket closed")
Exemple #17
0
def test_websocket_scope_interface():
    """
    A WebSocket built from a scope alone exposes that scope through the
    `Mapping` interface.
    """
    scope = {"type": "websocket", "path": "/abc/", "headers": []}
    websocket = WebSocket(dict(scope))
    assert websocket["type"] == "websocket"
    assert dict(websocket) == scope
    assert len(websocket) == 3
Exemple #18
0
 async def asgi(receive, send):
     """Accept, start ``respond`` in the background, and block on a receive.

     The receive is expected to block because the client sends nothing.
     Blocking there must not stop ``respond`` from running concurrently.
     """
     websocket = WebSocket(scope, receive=receive, send=send)
     await websocket.accept()
     # Keep a strong reference to the task. The event loop holds only weak
     # references to tasks, so an unreferenced task can be garbage-collected
     # before it finishes.
     responder = asyncio.ensure_future(respond(websocket))  # noqa: F841
     try:
         # this will block as the client does not send us data
         # it should not prevent `respond` from executing though
         await websocket.receive_json()
     except WebSocketDisconnect:
         pass
Exemple #19
0
    async def websocket(self, scope: Scope, receive: Receive,
                        send: Send) -> None:
        """Route a websocket connection to the ``Socket`` of its view module.

        If ``get_view`` finds no module for the URL path, or the module has
        no ``Socket`` attribute, the connection is closed with
        ``WS_1001_GOING_AWAY``.
        """
        websocket = WebSocket(scope, receive=receive, send=send)
        module = self.get_view(websocket.url.path)

        has_socket = module is not None and hasattr(module, "Socket")
        if not has_socket:
            await WebSocketClose(WS_1001_GOING_AWAY)(scope, receive, send)
            return

        socket_handler = getattr(module, "Socket")
        await socket_handler(websocket)
# Strong references to send tasks scheduled on an already-running loop. The
# event loop keeps only weak references to tasks, so without this they could
# be garbage-collected before completing.
_PENDING_SENDS = set()


def __send_message(websocket: WebSocket, message: str) -> None:
    """
    Send ``message`` to the given websocket as the JSON object ``{'data': message}``.

    Does nothing unless the websocket's client state is CONNECTED.

    With no event loop running in the current thread, the send runs to
    completion via ``asyncio.run``. When called from inside a running loop,
    ``asyncio.run`` would raise RuntimeError and the coroutine would never
    be awaited. In that case the send is scheduled as a task on the running
    loop, and this function returns without waiting for it.

    Args:
        websocket: Client that will be messaged
        message: Message to be sent
    """
    if websocket.client_state != WebSocketState.CONNECTED:
        return
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        loop = None
    if loop is None:
        asyncio.run(websocket.send_json({'data': message}))
        return
    task = loop.create_task(websocket.send_json({'data': message}))
    _PENDING_SENDS.add(task)
    task.add_done_callback(_PENDING_SENDS.discard)
Exemple #21
0
 async def get_thing(self, thing_id):
     """
     Get the thing this request is for.

     thing_id -- ID of the thing to get, in string form

     The collection of Things is read from ``app.state.things`` on a
     WebSocket built from this object's scope/receive/send.

     Returns the thing, or None if not found.
     """
     connection = WebSocket(self.scope, receive=self.receive, send=self.send)
     managed_things = connection.app.state.things
     thing = await managed_things.get_thing(thing_id)
     return thing
 async def asgi(receive, send):
     """Accept, then keep ``respond`` running while blocked on a receive.

     ``respond`` runs as an anyio child task. Meanwhile the current task
     waits on ``receive_json``, since the client sends nothing, and treats a
     disconnect as the normal end.
     """
     ws = WebSocket(scope, receive=receive, send=send)
     await ws.accept()
     async with anyio.create_task_group() as tg:
         tg.start_soon(respond, ws)
         try:
             await ws.receive_json()
         except WebSocketDisconnect:
             pass
Exemple #23
0
    async def handle_websocket(self, scope: Scope, receive: Receive, send: Send):
        """Serve the ``graphql-ws`` subscription protocol on one connection.

        Message handling:

        - ``GQL_CONNECTION_INIT``: send an ack and, if ``self.keep_alive``,
          start keep-alive.
        - ``GQL_CONNECTION_TERMINATE``: close the socket.
        - ``GQL_START``: begin a subscription and stream its results in a
          background task.
        - ``GQL_STOP``: close that one subscription (unknown ids are ignored).

        When the loop ends, keep-alive and all remaining subscriptions are
        torn down.
        """
        websocket = WebSocket(scope=scope, receive=receive, send=send)

        # operation id -> subscription generator / task streaming its results
        subscriptions: typing.Dict[str, typing.AsyncGenerator] = {}
        tasks = {}

        await websocket.accept(subprotocol="graphql-ws")

        try:
            while (
                websocket.client_state != WebSocketState.DISCONNECTED
                and websocket.application_state != WebSocketState.DISCONNECTED
            ):
                message = await websocket.receive_json()

                operation_id = message.get("id")
                message_type = message.get("type")

                if message_type == GQL_CONNECTION_INIT:
                    await websocket.send_json({"type": GQL_CONNECTION_ACK})

                    if self.keep_alive:
                        self._keep_alive_task = asyncio.create_task(
                            self.handle_keep_alive(websocket)
                        )
                elif message_type == GQL_CONNECTION_TERMINATE:
                    await websocket.close()
                elif message_type == GQL_START:
                    async_result = await self.start_subscription(
                        message.get("payload"), operation_id, websocket
                    )

                    subscriptions[operation_id] = async_result

                    tasks[operation_id] = asyncio.create_task(
                        self.handle_async_results(async_result, operation_id, websocket)
                    )
                elif message_type == GQL_STOP:  # pragma: no cover
                    if operation_id not in subscriptions:
                        # Unknown or already-removed operation: ignore this
                        # stop. A `return` here would end the whole
                        # connection handler over one stale message.
                        continue

                    await subscriptions[operation_id].aclose()
                    tasks[operation_id].cancel()
                    del tasks[operation_id]
                    del subscriptions[operation_id]
        except WebSocketDisconnect:  # pragma: no cover
            pass
        finally:
            # NOTE(review): assumes `_keep_alive_task` is initialised (e.g. to
            # None) elsewhere for when keep-alive never started — confirm.
            if self._keep_alive_task:
                self._keep_alive_task.cancel()

            for operation_id in subscriptions:
                await subscriptions[operation_id].aclose()
                tasks[operation_id].cancel()
Exemple #24
0
    async def _dispatch_ws(self, scope, receive, send):
        """Dispatch a websocket connection to its matching route.

        The ``before_ws_requests`` hooks run first, each via
        ``self.background`` with ``ws`` as a keyword argument. The route's
        endpoint then runs via ``self.background`` with ``ws`` positionally.
        Without a matching route, a ``websocket.close`` with code 1000 is sent.
        """
        ws = WebSocket(scope=scope, receive=receive, send=send)
        route = self.routes.get(self.path_matches_route(ws.url.path))

        if not route:
            await send({"type": "websocket.close", "code": 1000})
            return

        for hook in self.before_ws_requests:
            await self.background(hook, ws=ws)
        await self.background(route.endpoint, ws)
Exemple #25
0
        async def asgi(receive, send):
            nonlocal scope, self

            if scope["type"] == "websocket":
                ws = WebSocket(scope=scope, receive=receive, send=send)
                await self._dispatch_ws(ws)
            else:
                req = models.Request(scope, receive=receive, api=self)
                resp = await self._dispatch_request(req,
                                                    scope=scope,
                                                    send=send,
                                                    receive=receive)
                await resp(receive, send)
Exemple #26
0
 async def __call__(self, receive, send):
     """Open Redis connections, then serve the websocket.

     ``self.pub`` and ``self.sub`` are connected to the app's
     ``settings.REDIS_HOST``. After ``on_connect``, the websocket and the
     Redis channels named for ``constants.__ALL__`` and ``self.channel_name``
     are listened to concurrently via ``asyncio.gather``.
     """
     app_settings = self.scope.get('app').settings
     redis_address = f'redis://{app_settings.REDIS_HOST}'
     self.pub = await aioredis.create_redis(redis_address)
     self.sub = await aioredis.create_redis(redis_address)
     # NOTE(review): neither Redis connection is closed in this method —
     # confirm they are released elsewhere (e.g. on disconnect).
     websocket = WebSocket(self.scope, receive=receive, send=send)
     await self.on_connect(websocket)
     channels = [
         self.get_channel_name(constants.__ALL__),
         self.get_channel_name(self.channel_name),
     ]
     await asyncio.gather(
         self.listen_ws(websocket),
         self.listen_redis(websocket, channels),
     )
Exemple #27
0
async def consumption_endpoint(websocket: WebSocket):
    """Stream simulated usage readings for the machines a client registers.

    The client registers machines by sending JSON objects
    ``{"machine_id": ...}``. Non-JSON, non-object, missing-key, duplicate
    and unknown ids each get an ``{"error": ...}`` reply.

    About once per second, one ``{"id": ..., "usage": ...}`` message is sent
    per registered machine, in sorted id order. Usage values are drawn from
    ``random.randint(*ranges[id])``.
    """
    registered_ids = []

    def generate_messages():
        # One random reading per registered machine, in sorted id order.
        return [
            {"id": device, "usage": random.randint(*ranges[device])}
            for device in sorted(registered_ids)
        ]

    await websocket.accept()
    next_message = datetime.now()
    try:
        while True:
            # Wait for client input only until the next tick is due, instead
            # of waking every 10 ms to poll. Incoming messages are still
            # handled as soon as they arrive, but idle CPU use and the number
            # of cancelled receives drop sharply.
            timeout = max((next_message - datetime.now()).total_seconds(), 0)
            try:
                raw_data = await asyncio.wait_for(websocket.receive_text(),
                                                  timeout)
                data = json.loads(raw_data)
            except asyncio.TimeoutError:
                pass
            except json.JSONDecodeError as e:
                await websocket.send_json(
                    {"error": f"JSONDecodeError: {str(e)}"})
            else:
                if not isinstance(data, dict):
                    await websocket.send_json({"error": str(type(data))})
                elif "machine_id" not in data:
                    await websocket.send_json(
                        {"error": "KeyError: 'machine_id'"})
                elif data["machine_id"] in registered_ids:
                    await websocket.send_json({
                        "error":
                        f"DuplicateKeyError: {data['machine_id']} already registered"
                    })
                elif data["machine_id"] not in machine_ids:
                    await websocket.send_json({
                        "error":
                        f"InvalidIDError: {data['machine_id']} is not valid"
                    })
                else:
                    registered_ids.append(data["machine_id"])

            if datetime.now() > next_message:
                next_message += timedelta(seconds=1)
                for message in generate_messages():
                    await websocket.send_json(message)
    except WebSocketDisconnect:
        await websocket.close()
 async def asgi(receive, send):
     """Accept, then run ``reader`` and ``writer`` over a shared queue.

     Both coroutines receive the same ``websocket`` and ``asyncio.Queue``.
     They are handed to ``run_until_first_complete``, and the socket is
     closed afterwards.
     """
     websocket = WebSocket(scope, receive=receive, send=send)
     queue = asyncio.Queue()
     await websocket.accept()
     shared = {"websocket": websocket, "queue": queue}
     await run_until_first_complete(
         (reader, dict(shared)),
         (writer, dict(shared)),
     )
     await websocket.close()
Exemple #29
0
    async def dispatch(self) -> None:
        """Run the websocket lifecycle: connect, receive loop, disconnect.

        One ``state`` dict is passed to ``app.injector.inject`` for each of
        ``on_connect``, ``on_receive`` and ``on_disconnect``. The callables
        the injector returns are then awaited with no arguments. As the
        connection progresses, ``state["websocket_message"]`` and
        ``state["websocket_code"]`` are updated in place.

        Raises:
            exceptions.WebSocketConnectionException: if injecting or running
                ``on_connect`` fails; the original error is chained as the cause.
            Exception: any error from the receive loop other than
                ``exceptions.WebSocketException`` is re-raised with its
                implicit context suppressed, after recording
                ``WS_1011_INTERNAL_ERROR``. ``on_disconnect`` still runs
                before it propagates.
        """
        app = self.scope["app"]
        websocket = WebSocket(self.scope, self.receive, self.send)

        route, route_scope = app.router.get_route_from_scope(self.scope)

        # Mutable context shared with every injected handler. The receive
        # loop below overwrites websocket_message / websocket_code in place,
        # so later handlers see the latest values.
        state = {
            "scope": self.scope,
            "receive": self.receive,
            "send": self.send,
            "exc": None,
            "app": app,
            "path_params": route_scope["path_params"],
            "route": route,
            "websocket": websocket,
            "websocket_encoding": self.encoding,
            "websocket_code": status.WS_1000_NORMAL_CLOSURE,
            "websocket_message": None,
        }

        try:
            on_connect = await app.injector.inject(self.on_connect, state)
            await on_connect()
        except Exception as e:
            raise exceptions.WebSocketConnectionException(
                "Error connecting socket") from e

        try:
            state["websocket_message"] = await websocket.receive()

            # NOTE(review): presumably client_state leaves CONNECTED once a
            # disconnect message has been received, ending the loop — confirm.
            while websocket.client_state == WebSocketState.CONNECTED:
                on_receive = await app.injector.inject(self.on_receive, state)
                await on_receive()
                state["websocket_message"] = await websocket.receive()

            # The last message received (expected to be the disconnect)
            # supplies the close code; default to WS_1000_NORMAL_CLOSURE.
            state["websocket_code"] = int(state["websocket_message"].get(
                "code", status.WS_1000_NORMAL_CLOSURE))
        except exceptions.WebSocketException as e:
            # Handler-signalled termination: use the exception's close code.
            state["websocket_code"] = e.close_code
        except Exception as e:
            state["websocket_code"] = status.WS_1011_INTERNAL_ERROR
            # Re-raise the same exception with implicit context suppressed.
            raise e from None
        finally:
            # on_disconnect always runs and sees the final websocket_code.
            on_disconnect = await app.injector.inject(self.on_disconnect,
                                                      state)
            await on_disconnect()
 async def __call__(self, scope: Scope, receive: Receive,
                    send: Send) -> None:
     """ASGI entry point serving GraphQL over HTTP and websockets.

     HTTP requests:

     - ``GET`` returns the playground page when ``self.playground`` is truthy.
     - ``POST`` is passed to ``_handle_http_request``.
     - Anything else, including ``GET`` with the playground disabled, gets
       a 405 response.

     Websocket scopes are handed to ``_run_websocket_server``.

     Raises:
         ValueError: for any other scope type.
     """
     if scope["type"] == "http":
         request = Request(scope=scope, receive=receive)
         response: Response
         if request.method == "GET" and self.playground:
             response = HTMLResponse(PLAYGROUND_HTML)
         elif request.method == "POST":
             response = await self._handle_http_request(request)
         else:
             response = Response(status_code=405)
         await response(scope, receive, send)
     elif scope["type"] == "websocket":
         websocket = WebSocket(scope=scope, receive=receive, send=send)
         await self._run_websocket_server(websocket)
     else:
         # Plain f-string interpolation. The previous stray `$` (JavaScript
         # template syntax) leaked into the message, e.g. "$lifespan".
         raise ValueError(f"Unsupported scope type: {scope['type']}")