예제 #1
0
    def __init__(self, ctx, handshake_flow):
        """Create the two wsproto endpoints used to relay a WebSocket session.

        ``handshake_flow`` carries the HTTP upgrade exchange: its response
        headers are inspected for a negotiated permessage-deflate extension,
        and its request's host/path are used to replay the handshake locally
        between the two ``WSConnection`` objects so both are past the
        handshake when this constructor returns.

        :param ctx: layer context, forwarded to the base class.
        :param handshake_flow: the HTTP flow holding the upgrade request and
            response.
        """
        super().__init__(ctx)
        self.handshake_flow = handshake_flow
        self.flow: WebSocketFlow = None

        self.client_frame_buffer = []
        self.server_frame_buffer = []

        # Keyed by self.client_conn / self.server_conn.
        self.connections: dict[object, WSConnection] = {}

        # Only permessage-deflate is recognised; detection is a substring test
        # on the negotiated Sec-WebSocket-Extensions response header.
        client_extensions = []
        server_extensions = []
        if 'Sec-WebSocket-Extensions' in handshake_flow.response.headers:
            if PerMessageDeflate.name in handshake_flow.response.headers['Sec-WebSocket-Extensions']:
                client_extensions = [PerMessageDeflate()]
                server_extensions = [PerMessageDeflate()]
        # Facing the client we play the server role, and vice versa.
        self.connections[self.client_conn] = WSConnection(ConnectionType.SERVER)
        self.connections[self.server_conn] = WSConnection(ConnectionType.CLIENT)

        if client_extensions:
            client_extensions[0].finalize(handshake_flow.response.headers['Sec-WebSocket-Extensions'])
        if server_extensions:
            server_extensions[0].finalize(handshake_flow.response.headers['Sec-WebSocket-Extensions'])

        request = Request(extensions=client_extensions, host=handshake_flow.request.host, target=handshake_flow.request.path)
        data = self.connections[self.server_conn].send(request)
        self.connections[self.client_conn].receive_data(data)

        event = next(self.connections[self.client_conn].events())
        assert isinstance(event, events.Request)

        data = self.connections[self.client_conn].send(AcceptConnection(extensions=server_extensions))
        self.connections[self.server_conn].receive_data(data)
        # Consume the event outside the assert: asserts are stripped under
        # ``python -O``, which would leave AcceptConnection queued on the
        # server-side connection.
        event = next(self.connections[self.server_conn].events())
        assert isinstance(event, events.AcceptConnection)
예제 #2
0
def _make_handshake(
    request_headers: Headers,
    accept_headers: Optional[Headers] = None,
    subprotocol: Optional[str] = None,
    extensions: Optional[List[Extension]] = None,
) -> Tuple[h11.InformationalResponse, bytes]:
    """Let a wsproto server accept an upgrade sent by a plain h11 client.

    The client sends ``GET /`` with the mandatory WebSocket upgrade headers
    followed by ``request_headers``; the server answers with an
    AcceptConnection built from the remaining arguments.

    :return: the first event the h11 client parses from the reply, and the
        nonce that was sent as ``Sec-WebSocket-Key``.
    """
    h11_client = h11.Connection(h11.CLIENT)
    ws_server = WSConnection(SERVER)
    nonce = generate_nonce()

    mandatory = [
        (b"Host", b"localhost"),
        (b"Connection", b"Keep-Alive, Upgrade"),
        (b"Upgrade", b"WebSocket"),
        (b"Sec-WebSocket-Version", b"13"),
        (b"Sec-WebSocket-Key", nonce),
    ]
    upgrade = h11.Request(method="GET", target="/", headers=mandatory + request_headers)
    ws_server.receive_data(h11_client.send(upgrade))

    accept = AcceptConnection(
        extra_headers=accept_headers or [],
        subprotocol=subprotocol,
        extensions=extensions or [],
    )
    h11_client.receive_data(ws_server.send(accept))
    return h11_client.next_event(), nonce
예제 #3
0
    def __init__(
        self,
        app: Type[ASGIFramework],
        loop: asyncio.AbstractEventLoop,
        config: Config,
        transport: asyncio.BaseTransport,
        *,
        upgrade_request: Optional[h11.Request] = None,
    ) -> None:
        """Set up a server-side wsproto connection for an ASGI app.

        :param app: the ASGI application class to serve.
        :param loop: event loop, forwarded to the base class.
        :param config: server configuration; ``websocket_max_message_size``
            bounds the message buffer.
        :param transport: the asyncio transport, forwarded to the base class.
        :param upgrade_request: an h11 request whose headers and target are
            handed to wsproto immediately, followed by ``handle_events()``.
        """
        super().__init__(loop, config, transport, "wsproto")
        self.stop_keep_alive_timeout()
        self.app = app
        self.connection = WSConnection(ConnectionType.SERVER)

        # ASGI bookkeeping; starts in the HANDSHAKE state with nothing set.
        self.state = ASGIWebsocketState.HANDSHAKE
        self.scope: Optional[dict] = None
        self.response: Optional[dict] = None
        self.task: Optional[asyncio.Future] = None
        self.app_queue: asyncio.Queue = asyncio.Queue()

        self.buffer = WebsocketBuffer(self.config.websocket_max_message_size)

        if upgrade_request is None:
            return
        self.connection.initiate_upgrade_connection(
            upgrade_request.headers, upgrade_request.target)
        self.handle_events()
예제 #4
0
def test_upgrade_request() -> None:
    """An upgrade passed to initiate_upgrade_connection surfaces as a Request.

    Host/extension/subprotocol headers become Request fields and are absent
    from ``extra_headers``; the other headers are kept there verbatim.
    """
    server = WSConnection(SERVER)
    upgrade_headers = [
        (b"Host", b"localhost"),
        (b"Connection", b"Keep-Alive, Upgrade"),
        (b"Upgrade", b"WebSocket"),
        (b"Sec-WebSocket-Version", b"13"),
        (b"Sec-WebSocket-Key", generate_nonce()),
        (b"X-Foo", b"bar"),
    ]
    server.initiate_upgrade_connection(upgrade_headers, "/")
    request = cast(Request, next(server.events()))

    assert request.extensions == []
    assert request.host == "localhost"
    assert request.subprotocols == []
    assert request.target == "/"

    extra = normed_header_dict(request.extra_headers)
    for lifted in (b"host", b"sec-websocket-extensions", b"sec-websocket-protocol"):
        assert lifted not in extra
    kept = {
        b"connection": b"Keep-Alive, Upgrade",
        b"sec-websocket-version": b"13",
        b"upgrade": b"WebSocket",
        b"x-foo": b"bar",
    }
    for name, value in kept.items():
        assert extra[name] == value
예제 #5
0
    async def start_server(self, sock: anyio.abc.SocketStream, filter=None):  # pylint: disable=W0622
        """Start a server WS connection on this socket.

        Filter: an async callable that gets passed the initial Request.
        It may return an AcceptConnection message, a bool, or a string (the
        subprotocol to use). A falsy result rejects the connection; ``True``
        (or no filter at all) accepts it with the client's first offered
        subprotocol, or with no subprotocol if the client offered none.
        Returns: the Request message.
        Raises: ConnectionError if the first event is not a Request, or if the
        connection was not accepted.
        """
        assert self._scope is None
        self._scope = True
        self._sock = sock
        self._connection = WSConnection(ConnectionType.SERVER)

        try:
            event = await self._next_event()
            if not isinstance(event, Request):
                raise ConnectionError("Failed to establish a connection",
                                      event)
            msg = None
            if filter is not None:
                msg = await filter(event)
                if not msg:
                    msg = RejectConnection()
                elif msg is True:
                    msg = None
                elif isinstance(msg, str):
                    msg = AcceptConnection(subprotocol=msg)
            if not msg:
                # Clients are not required to offer a subprotocol; indexing
                # [0] unconditionally would raise IndexError for them.
                subprotocol = event.subprotocols[0] if event.subprotocols else None
                msg = AcceptConnection(subprotocol=subprotocol)
            data = self._connection.send(msg)
            await self._sock.send_all(data)
            if not isinstance(msg, AcceptConnection):
                raise ConnectionError("Not accepted", msg)
            # Honour the documented contract (previously returned None).
            return event
        finally:
            self._scope = None
예제 #6
0
def _make_handshake(
    response_status,
    response_headers,
    subprotocols=None,
    extensions=None,
    auto_accept_key=True,
):
    """Feed a wsproto client's handshake to an h11 server and return its events.

    :param response_status: status code of the server's InformationalResponse.
    :param response_headers: headers for that response. The sequence is
        copied, so the caller's list is never modified.
    :param subprotocols: subprotocols the client offers (default none).
    :param extensions: extensions the client offers (default none).
    :param auto_accept_key: when true, append a correct Sec-WebSocket-Accept
        header computed from the client's Sec-WebSocket-Key.
    :return: all events the wsproto client produced from the response.
    """
    client = WSConnection(CLIENT)
    server = h11.Connection(h11.SERVER)
    server.receive_data(
        client.send(
            Request(
                host="localhost",
                target="/",
                subprotocols=subprotocols or [],
                extensions=extensions or [],
            )
        )
    )
    request = server.next_event()
    # Work on a copy: appending to the caller's list would leak an extra
    # Sec-WebSocket-Accept header into any later call reusing that list.
    headers = list(response_headers)
    if auto_accept_key:
        full_request_headers = normed_header_dict(request.headers)
        headers.append(
            (
                b"Sec-WebSocket-Accept",
                generate_accept_token(full_request_headers[b"sec-websocket-key"]),
            )
        )
    response = h11.InformationalResponse(
        status_code=response_status, headers=headers
    )
    client.receive_data(server.send(response))

    return list(client.events())
예제 #7
0
def get_case_count(server):
    """Fetch the number of test cases from ``<server>/getCaseCount``.

    Opens a WebSocket, reassembles the text message the server sends (which
    may be fragmented across frames and reads), decodes it as JSON, sends a
    normal-closure Close frame and closes the socket.

    :param server: base URL of the test server (``scheme://host[:port]``);
        port 80 is used when none is given.
    :return: the decoded JSON value, or ``None`` if the server closed the
        TCP connection before sending it.
    """
    uri = urlparse(server + '/getCaseCount')
    connection = WSConnection(CLIENT)
    case_count = None
    # Must survive across recv() calls: a fragmented message can span reads.
    text = ""
    sock = socket.socket()
    try:
        sock.connect((uri.hostname, uri.port or 80))
        sock.sendall(connection.send(Request(host=uri.netloc, target=uri.path)))

        while case_count is None:
            in_data = sock.recv(65535)
            if not in_data:
                # b"" means the peer closed the connection; nothing more will come.
                break
            connection.receive_data(in_data)
            out_data = b""
            for event in connection.events():
                if isinstance(event, TextMessage):
                    text += event.data
                    if event.message_finished:
                        case_count = json.loads(text)
                        out_data += connection.send(CloseConnection(code=CloseReason.NORMAL_CLOSURE))
            # Flush once per batch of events so no bytes are sent twice.
            if out_data:
                try:
                    sock.sendall(out_data)
                except CONNECTION_EXCEPTIONS:
                    break
    finally:
        sock.close()
    return case_count
예제 #8
0
def update_reports(server, agent):
    """Hit ``<server>/updateReports?agent=<agent>`` over a WebSocket.

    Performs the opening handshake and, once it is accepted, sends a
    normal-closure Close frame. The loop also ends if the server closes the
    TCP connection first, and the socket is closed on every exit path.
    """
    uri = urlparse(server + '/updateReports?agent=%s' % agent)
    connection = WSConnection(CLIENT)
    sock = socket.socket()
    try:
        sock.connect((uri.hostname, uri.port or 80))
        sock.sendall(
            connection.send(
                Request(host=uri.netloc, target='%s?%s' % (uri.path, uri.query))))

        closed = False
        while not closed:
            data = sock.recv(65535)
            if not data:
                # Server hung up without accepting; b"" would otherwise repeat forever.
                break
            connection.receive_data(data)
            for event in connection.events():
                if isinstance(event, AcceptConnection):
                    sock.sendall(
                        connection.send(
                            CloseConnection(code=CloseReason.NORMAL_CLOSURE)))
                    closed = True
                    break
    finally:
        try:
            sock.close()
        except CONNECTION_EXCEPTIONS:
            pass
예제 #9
0
    async def init_for_client(
        cls: Type["Transport"], stream: Stream, host: str, keepalive: Optional[int] = None
    ) -> "Transport":
        """Build a client transport on ``stream`` and complete the handshake.

        The opening Request targets TRANSPORT_TARGET on ``host`` and carries
        our User-Agent header.

        :raises TransportError: if the server's first event is not
            AcceptConnection.
        """
        transport = cls(stream, WSConnection(ConnectionType.CLIENT), keepalive)

        # A client WebSocket opens the handshake itself by sending a Request.
        opening = Request(
            host=host,
            target=TRANSPORT_TARGET,
            extra_headers=[(b"User-Agent", USER_AGENT.encode())],
        )
        await transport._net_send(opening)

        answer = await transport._next_ws_event()
        if not isinstance(answer, AcceptConnection):
            transport.logger.warning("Unexpected event during WebSocket handshake", ws_event=answer)
            raise TransportError(f"Unexpected event during WebSocket handshake: {answer}")

        transport.logger.debug("WebSocket negotiation complete", ws_event=answer)
        return transport
예제 #10
0
def test_connection_send_state() -> None:
    """A client goes CONNECTING -> OPEN on a valid 101 reply.

    Once OPEN, sending another Request raises LocalProtocolError, and feeding
    non-frame bytes produces exactly one event.
    """
    client = WSConnection(CLIENT)
    assert client.state is ConnectionState.CONNECTING

    server = h11.Connection(h11.SERVER)
    server.receive_data(client.send(Request(
        host="localhost",
        target="/",
    )))
    headers = normed_header_dict(server.next_event().headers)
    response = h11.InformationalResponse(
        status_code=101,
        headers=[
            (b"connection", b"Upgrade"),
            (b"upgrade", b"WebSocket"),
            (
                b"Sec-WebSocket-Accept",
                generate_accept_token(headers[b"sec-websocket-key"]),
            ),
        ],
    )
    client.receive_data(server.send(response))
    assert len(list(client.events())) == 1
    assert client.state is ConnectionState.OPEN  # type: ignore # https://github.com/python/mypy/issues/9005

    # The exception info was never inspected, so it is not bound.
    with pytest.raises(LocalProtocolError):
        client.send(Request(host="localhost", target="/"))

    client.receive_data(b"foobar")
    assert len(list(client.events())) == 1
예제 #11
0
def handle_todo_post_save(sender, instance, created, **kwargs):
    """Push the saved instance's primary key to the app's WebSocket endpoint.

    Does nothing unless ``sender`` has an ``APP_PORT`` attribute. Otherwise it:

    * connects to ``localhost:<APP_PORT>`` and opens a WebSocket on
      ``/ws/todos``;
    * sends ``str(instance.pk)`` as a message;
    * performs the closing handshake;
    * deletes ``sender.APP_PORT``.

    The TCP socket is closed even if one of these steps raises.
    """
    if not hasattr(sender, 'APP_PORT'):
        return
    ws = WSConnection(ConnectionType.CLIENT)
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as conn:
        conn.connect(('localhost', int(sender.APP_PORT)))

        # The request-target must be in origin-form, i.e. start with "/".
        net_send(
            ws.send(Request(
                host=f"localhost:{sender.APP_PORT}",
                target="/ws/todos")),
            conn
        )
        net_recv(ws, conn)
        handle_events(ws)

        net_send(ws.send(Message(data=str(instance.pk))), conn)
        net_recv(ws, conn)
        handle_events(ws)

        net_send(ws.send(CloseConnection(code=1000)), conn)
        net_recv(ws, conn)
        conn.shutdown(socket.SHUT_WR)
        net_recv(ws, conn)

    del sender.APP_PORT
예제 #12
0
    async def start_client(self,
                           sock: anyio.abc.SocketStream,
                           addr,
                           path: str,
                           headers: Optional[List] = None,
                           subprotocols: Optional[List[str]] = None):
        """Start a client WS connection on this socket.

        Sends the opening Request (host ``addr[0]``, target ``path``, plus
        optional extra headers and offered subprotocols) and waits for the
        server's reply.

        Returns: the AcceptConnection message.
        Raises: ConnectionError if the first event is not AcceptConnection.
        """
        # Check and claim _scope before touching any state. Otherwise a call
        # that trips the assert would already have replaced _sock and
        # _connection and sent a handshake.
        assert self._scope is None
        self._scope = True
        try:
            self._sock = sock
            self._connection = WSConnection(ConnectionType.CLIENT)
            data = self._connection.send(
                Request(
                    host=addr[0],
                    target=path,
                    extra_headers=headers if headers is not None else [],
                    subprotocols=subprotocols if subprotocols is not None else []))
            await self._sock.send_all(data)

            event = await self._next_event()
            if not isinstance(event, AcceptConnection):
                raise ConnectionError("Failed to establish a connection",
                                      event)
            return event
        finally:
            self._scope = None
예제 #13
0
def _make_connection_request(request_headers, method="GET"):
    # type: (List[Tuple[str, str]], str) -> Request
    """Send an h11 request for ``/`` to a wsproto server.

    :param request_headers: the complete header list for the request.
    :param method: HTTP method to use (default ``"GET"``).
    :return: the first event the server emits for it.
    """
    client = h11.Connection(h11.CLIENT)
    server = WSConnection(SERVER)
    server.receive_data(
        client.send(h11.Request(method=method, target="/", headers=request_headers))
    )
    return next(server.events())
예제 #14
0
def _make_connection_request(request_headers: Headers,
                             method: str = "GET") -> Request:
    """Return the first event a wsproto server produces for an h11 request.

    The request targets ``/`` using ``method`` and exactly ``request_headers``.
    """
    h11_client = h11.Connection(h11.CLIENT)
    ws_server = WSConnection(SERVER)
    outgoing = h11_client.send(
        h11.Request(method=method, target="/", headers=request_headers))
    ws_server.receive_data(outgoing)
    return next(ws_server.events())  # type: ignore
예제 #15
0
 def __init__(self, host: str, port: int):
     """Check and store the listen address and create the wsproto endpoint.

     The stream server and the client socket start out as ``None``; the
     running flag starts out ``True``.
     """
     self._check_init_arguments(host, port)
     self._host, self._port = host, port
     self._running = True
     self._ws = WSConnection(ConnectionType.SERVER)
     # Placeholders; nothing is bound or accepted yet at construction time.
     self._server: StreamServer = None
     self._client: socket = None  # client socket provided by the StreamServer
예제 #16
0
파일: wsecho.py 프로젝트: drewja/werkzeug
def websocket(request):
    """Serve an echo WebSocket on the raw socket behind a WSGI request.

    The WSGI server has already read the upgrade request, so it is rebuilt
    from ``request`` and fed to wsproto. The loop then:

    * accepts the handshake;
    * answers pings and acknowledges a client close;
    * echoes text messages;
    * closes with NORMAL_CLOSURE when the client sends the text ``"quit"``.

    It also stops if the peer drops the TCP connection.

    :raises InternalServerError: if the server exposes no raw socket.
    :return: an empty 204 response, since WSGI still requires one.
    """
    # The underlying socket must be provided by the server. Gunicorn and
    # Werkzeug's dev server are known to support this.
    stream = request.environ.get("werkzeug.socket")

    if stream is None:
        stream = request.environ.get("gunicorn.socket")

    if stream is None:
        raise InternalServerError()

    # Initialize the wsproto connection. Need to recreate the request
    # data that was read by the WSGI server already.
    ws = WSConnection(ConnectionType.SERVER)
    request_lines = [b"GET %s HTTP/1.1\r\n" % request.path.encode("utf8")]
    request_lines.extend(
        f"{header}: {value}\r\n".encode() for header, value in request.headers.items()
    )
    request_lines.append(b"\r\n")
    ws.receive_data(b"".join(request_lines))
    running = True

    while True:
        out_data = b""

        for event in ws.events():
            if isinstance(event, WSRequest):
                out_data += ws.send(AcceptConnection())
            elif isinstance(event, CloseConnection):
                out_data += ws.send(event.response())
                running = False
            elif isinstance(event, Ping):
                out_data += ws.send(event.response())
            elif isinstance(event, TextMessage):
                # echo the incoming message back to the client
                if event.data == "quit":
                    out_data += ws.send(
                        CloseConnection(CloseReason.NORMAL_CLOSURE, "bye")
                    )
                    running = False
                else:
                    out_data += ws.send(Message(data=event.data))

        if out_data:
            # sendall: a single send() may transmit only part of the buffer.
            stream.sendall(out_data)

        if not running:
            break

        in_data = stream.recv(4096)
        if not in_data:
            # recv() returns b"" once the peer has closed the TCP connection;
            # without this check the loop would spin forever on empty reads.
            break
        ws.receive_data(in_data)

    # The connection will be closed at this point, but WSGI still
    # requires a response.
    return Response("", status=204)
예제 #17
0
 def __init__(self,
              path: str,
              *,
              framework: ASGIFramework = echo_framework) -> None:
     """Wire a wsproto client to a WebsocketServer over an in-memory stream pair.

     The client's opening Request for ``path`` (host ``hypercorn``) is fed
     directly into the server's wsproto connection.
     """
     client_stream, server_stream = trio.testing.memory_stream_pair()
     self.client_stream = client_stream
     server_stream.socket = MockSocket()
     self.server = WebsocketServer(framework, Config(), server_stream)
     self.connection = WSConnection(ConnectionType.CLIENT)
     opening = self.connection.send(Request(target=path, host="hypercorn"))
     self.server.connection.receive_data(opening)
예제 #18
0
 def _establish_websocket_handshake(self, host: str, path: str,
                                    headers: Headers, extensions: List[str],
                                    sub_protocols: List[str]) -> None:
     """Create a client wsproto connection and send its opening Request.

     ``None`` for ``headers``, ``extensions`` or ``sub_protocols`` is treated
     as an empty list. Only the request is sent; no reply is read here.
     """
     self._ws = WSConnection(ConnectionType.CLIENT)
     opening = Request(
         host=host,
         target=path,
         extra_headers=[] if headers is None else headers,
         extensions=[] if extensions is None else extensions,
         subprotocols=[] if sub_protocols is None else sub_protocols,
     )
     self._sock.sendall(self._ws.send(opening))
예제 #19
0
 def __init__(
     self,
     path: str,
     event_loop: asyncio.AbstractEventLoop,
     *,
     framework: Type[ASGIFramework] = EchoFramework,
 ) -> None:
     """Connect a wsproto client to a WebsocketServer through a MockTransport.

     The client's opening Request for ``path`` (host ``hypercorn``) is handed
     to the server via ``data_received``.
     """
     self.transport = MockTransport()
     self.server = WebsocketServer(  # type: ignore
         framework, event_loop, Config(), self.transport)
     self.connection = WSConnection(ConnectionType.CLIENT)
     opening_request = Request(target=path, host="hypercorn")
     self.server.data_received(self.connection.send(opening_request))
예제 #20
0
def handle_connection(stream: socket.socket) -> None:
    """
    Handle a connection.

    The server operates a request/response cycle, so it performs a synchronous
    loop:

    1) Read data from network into wsproto
    2) Get new events and handle them
    3) Send data from wsproto to network

    The loop ends after replying to the client's close frame, or when the
    peer closes the TCP connection (``recv`` returns no data).

    :param stream: a socket stream
    """
    ws = WSConnection(ConnectionType.SERVER)
    running = True

    while running:
        # 1) Read data from network
        in_data = stream.recv(RECEIVE_BYTES)
        print("Received {} bytes".format(len(in_data)))
        if not in_data:
            # An empty read means EOF; without this check the loop would spin
            # forever re-reading nothing.
            print("Connection closed by peer")
            break
        ws.receive_data(in_data)

        # 2) Get new events and handle them
        out_data = b""
        for event in ws.events():
            if isinstance(event, Request):
                # Negotiate new WebSocket connection
                print("Accepting WebSocket upgrade")
                out_data += ws.send(AcceptConnection())
            elif isinstance(event, CloseConnection):
                # Print log message and break out
                print("Connection closed: code={} reason={}".format(
                    event.code, event.reason))
                out_data += ws.send(event.response())
                running = False
            elif isinstance(event, TextMessage):
                # Reverse text and send it back to wsproto
                print("Received request and sending response")
                out_data += ws.send(Message(data=event.data[::-1]))
            elif isinstance(event, Ping):
                # Reply explicitly with the matching Pong built by
                # event.response().
                print("Received ping and sending pong")
                out_data += ws.send(event.response())
            else:
                print(f"Unknown event: {event!r}")

        # 3) Send data from wsproto to network
        print("Sending {} bytes".format(len(out_data)))
        # sendall() keeps writing until every byte is sent; send() may not.
        stream.sendall(out_data)
예제 #21
0
def wsproto_demo(host, port):
    '''
    Demonstrate wsproto:

    0) Open TCP connection
    1) Negotiate WebSocket opening handshake
    2) Send a message and display response
    3) Send ping and display pong
    4) Negotiate WebSocket closing handshake

    The TCP socket is closed when the demo finishes or fails.

    :param host: host name or address of the WebSocket server
    :param port: TCP port of the WebSocket server
    '''

    # 0) Open TCP connection
    print('Connecting to {}:{}'.format(host, port))
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as conn:
        conn.connect((host, port))

        # 1) Negotiate WebSocket opening handshake
        print('Opening WebSocket')
        ws = WSConnection(ConnectionType.CLIENT)
        # Because this is a client WebSocket, we need to initiate the connection
        # handshake by sending a Request event.
        net_send(ws.send(Request(host=host, target='server')), conn)
        net_recv(ws, conn)
        handle_events(ws)

        # 2) Send a message and display response
        message = "wsproto is great"
        print('Sending message: {}'.format(message))
        net_send(ws.send(Message(data=message)), conn)
        net_recv(ws, conn)
        handle_events(ws)

        # 3) Send ping and display pong
        payload = b"table tennis"
        print('Sending ping: {}'.format(payload))
        net_send(ws.send(Ping(payload=payload)), conn)
        net_recv(ws, conn)
        handle_events(ws)

        # 4) Negotiate WebSocket closing handshake
        print('Closing WebSocket')
        net_send(ws.send(CloseConnection(code=1000, reason='sample reason')), conn)
        # After sending the closing frame, we won't get any more events. The server
        # should send a reply and then close the connection, so we need to receive
        # twice:
        net_recv(ws, conn)
        conn.shutdown(socket.SHUT_WR)
        net_recv(ws, conn)
예제 #22
0
def _make_handshake_rejection(status_code, body=None):
    """Return the events a wsproto client sees when an h11 server rejects it.

    The server answers the opening Request with a plain ``status_code``
    response. When ``body`` is given it follows with a matching
    Content-Length header.
    """
    client = WSConnection(CLIENT)
    server = h11.Connection(h11.SERVER)
    server.receive_data(client.send(Request(host="localhost", target="/")))

    has_body = body is not None
    headers = [("Content-Length", str(len(body)))] if has_body else []
    client.receive_data(
        server.send(h11.Response(status_code=status_code, headers=headers))
    )
    if has_body:
        client.receive_data(server.send(h11.Data(data=body)))
    client.receive_data(server.send(h11.EndOfMessage()))

    return list(client.events())
예제 #23
0
    async def init_for_server(cls, stream):
        """Wait for a client's WebSocket handshake on ``stream`` and accept it.

        :raises TransportError: if no Request arrives within
            WEBSOCKET_HANDSHAKE_TIMEOUT, or the first event is something else.
        """
        transport = cls(stream, WSConnection(ConnectionType.SERVER))

        # If the timeout fires, `event` keeps this placeholder string, which
        # then ends up in the error message below.
        event = "Websocket handshake timeout"
        with trio.move_on_after(WEBSOCKET_HANDSHAKE_TIMEOUT):
            event = await transport._next_ws_event()

        if not isinstance(event, Request):
            transport.logger.warning("Unexpected event during WebSocket handshake", ws_event=event)
            raise TransportError(f"Unexpected event during WebSocket handshake: {event}")

        transport.logger.debug("Accepting WebSocket upgrade")
        await transport._net_send(AcceptConnection())
        return transport
예제 #24
0
    def __init__(self,
                 sock=None,
                 connection_type=None,
                 receive_bytes=4096,
                 thread_class=threading.Thread,
                 event_class=threading.Event):
        """Wrap ``sock`` in a wsproto connection and start the worker thread.

        Runs ``self.handshake()`` synchronously, then starts ``self._thread``
        in a ``thread_class`` thread. It then blocks on ``self.event``
        (presumably set by the worker once it is ready; confirm) and clears
        the event afterwards.
        """
        self.sock = sock
        self.receive_bytes = receive_bytes
        self.input_buffer = []
        self.connected = False
        self.is_server = connection_type == ConnectionType.SERVER
        self.event = event_class()
        self.ws = WSConnection(connection_type)

        self.handshake()

        worker = thread_class(target=self._thread)
        self.thread = worker
        worker.start()
        # Wait for the first signal, then reset the event for later reuse.
        self.event.wait()
        self.event.clear()
예제 #25
0
    async def init_for_client(cls, stream: Stream, host: str) -> "Transport":
        """Build a client transport on ``stream`` and handshake with ``host``.

        :raises TransportError: if the server's first event is not
            AcceptConnection.
        """
        transport = cls(stream, WSConnection(ConnectionType.CLIENT))

        # A client WebSocket opens the handshake itself by sending a Request.
        await transport._net_send(Request(host=host, target=TRANSPORT_TARGET))

        answer = await transport._next_ws_event()
        if not isinstance(answer, AcceptConnection):
            transport.logger.warning("Unexpected event during WebSocket handshake", ws_event=answer)
            raise TransportError(f"Unexpected event during WebSocket handshake: {answer}")

        transport.logger.debug("WebSocket negotiation complete", ws_event=answer)
        return transport
예제 #26
0
    async def init_for_server(  # type: ignore[misc]
        cls: Type["Transport"], stream: Stream, upgrade_request: Optional[H11Request] = None
    ) -> "Transport":
        """Accept a client's WebSocket handshake on ``stream``.

        If ``upgrade_request`` is given, its headers and target are handed to
        wsproto before waiting for events.

        :raises TransportError: if no Request arrives within
            WEBSOCKET_HANDSHAKE_TIMEOUT, or the first event is something else.
        """
        ws = WSConnection(ConnectionType.SERVER)
        if upgrade_request:
            ws.initiate_upgrade_connection(
                headers=upgrade_request.headers, path=upgrade_request.target
            )
        transport = cls(stream, ws)

        # Placeholder kept as the "event" if the timeout expires first.
        event: Union[str, Event] = "Websocket handshake timeout"
        with trio.move_on_after(WEBSOCKET_HANDSHAKE_TIMEOUT):
            event = await transport._next_ws_event()

        if not isinstance(event, Request):
            transport.logger.warning("Unexpected event during WebSocket handshake", ws_event=event)
            raise TransportError(f"Unexpected event during WebSocket handshake: {event}")

        transport.logger.debug("Accepting WebSocket upgrade")
        await transport._net_send(AcceptConnection())
        return transport
예제 #27
0
def _make_handshake_rejection(
    status_code: int, body: Optional[bytes] = None
) -> List[Event]:
    """Have a wsproto server reject an h11 client's upgrade request.

    When ``body`` is given, the rejection carries a matching content-length
    header and the body follows as RejectData.

    :return: every event the h11 client parses, up to and including
        EndOfMessage.
    """
    client = h11.Connection(h11.CLIENT)
    server = WSConnection(SERVER)
    nonce = generate_nonce()
    upgrade = h11.Request(
        method="GET",
        target="/",
        headers=[
            (b"Host", b"localhost"),
            (b"Connection", b"Keep-Alive, Upgrade"),
            (b"Upgrade", b"WebSocket"),
            (b"Sec-WebSocket-Version", b"13"),
            (b"Sec-WebSocket-Key", nonce),
        ],
    )
    server.receive_data(client.send(upgrade))

    if body is None:
        client.receive_data(server.send(RejectConnection(status_code=status_code)))
    else:
        rejection = RejectConnection(
            headers=[(b"content-length", b"%d" % len(body))],
            status_code=status_code,
            has_body=True,
        )
        client.receive_data(server.send(rejection))
        client.receive_data(server.send(RejectData(data=body)))

    collected = []
    event = None
    while not isinstance(event, h11.EndOfMessage):
        event = client.next_event()
        collected.append(event)
    return collected
예제 #28
0
def new_conn(sock: socket.socket) -> None:
    """Serve one echo WebSocket connection on ``sock`` until it closes.

    Behaviour:

    * accepts the handshake with PerMessageDeflate in the extensions;
    * echoes every Message event, preserving ``message_finished``;
    * answers pings and replies to a close if the connection is not already
      CLOSED.

    The socket is closed on return, including when wsproto raises.
    """
    global count
    print("test_server.py received connection {}".format(count))
    count += 1
    ws = WSConnection(SERVER)
    closed = False
    try:
        while not closed:
            try:
                data: Optional[bytes] = sock.recv(65535)
            except socket.error:
                data = None

            # An empty read or a socket error is passed on as None, which
            # signals end-of-stream to wsproto.
            ws.receive_data(data or None)

            outgoing_data = b""
            for event in ws.events():
                if isinstance(event, Request):
                    outgoing_data += ws.send(
                        AcceptConnection(extensions=[PerMessageDeflate()]))
                elif isinstance(event, Message):
                    outgoing_data += ws.send(
                        Message(data=event.data,
                                message_finished=event.message_finished))
                elif isinstance(event, Ping):
                    outgoing_data += ws.send(event.response())
                elif isinstance(event, CloseConnection):
                    closed = True
                    if ws.state is not ConnectionState.CLOSED:
                        outgoing_data += ws.send(event.response())

            if not data:
                closed = True

            try:
                sock.sendall(outgoing_data)
            except socket.error:
                closed = True
    finally:
        # Always release the socket, even if wsproto raised a protocol error.
        sock.close()
예제 #29
0
    async def init_for_server(cls, stream, first_request_data=None):
        """Accept a client's WebSocket handshake on ``stream``.

        ``first_request_data``, if truthy, is fed to wsproto before anything
        is read from the stream.

        :raises TransportError: if wsproto reports a RemoteProtocolError, if
            no Request arrives within WEBSOCKET_HANDSHAKE_TIMEOUT, or if the
            first event is something else.
        """
        ws = WSConnection(ConnectionType.SERVER)
        try:
            if first_request_data:
                ws.receive_data(first_request_data)
            transport = cls(stream, ws)

            # Left as the "event" if the timeout expires before one arrives.
            event = "Websocket handshake timeout"
            with trio.move_on_after(WEBSOCKET_HANDSHAKE_TIMEOUT):
                event = await transport._next_ws_event()
        except RemoteProtocolError as exc:
            raise TransportError(f"Invalid WebSocket query: {exc}") from exc

        if not isinstance(event, Request):
            transport.logger.warning("Unexpected event during WebSocket handshake",
                                     ws_event=event)
            raise TransportError(
                f"Unexpected event during WebSocket handshake: {event}")

        transport.logger.debug("Accepting WebSocket upgrade")
        await transport._net_send(AcceptConnection())
        return transport
예제 #30
0
    def __init__(
        self,
        app: ASGIFramework,
        config: Config,
        stream: trio.abc.Stream,
        *,
        upgrade_request: Optional[h11.Request] = None,
    ) -> None:
        """Set up a server-side wsproto connection for an ASGI app over ``stream``.

        :param app: the ASGI application to serve.
        :param config: server configuration; ``websocket_max_message_size``
            bounds the message buffer.
        :param stream: the trio stream, passed to the base class together
            with the ``"wsproto"`` label.
        :param upgrade_request: if given, its headers and target are passed to
            wsproto via ``initiate_upgrade_connection``.
        """
        super().__init__(stream, "wsproto")
        self.app = app
        self.config = config
        self.connection = WSConnection(ConnectionType.SERVER)
        # Both start unset; presumably filled in once the handshake is processed.
        self.response: Optional[dict] = None
        self.scope: Optional[dict] = None
        # NOTE(review): presumably serialises writes to the stream — confirm.
        self.send_lock = trio.Lock()
        self.state = ASGIWebsocketState.HANDSHAKE

        self.buffer = WebsocketBuffer(self.config.websocket_max_message_size)
        # Memory channel with a buffer of 10; by its name it presumably carries
        # messages sent by the ASGI app — verify against the consumer.
        self.app_send_channel, self.app_receive_channel = trio.open_memory_channel(
            10)

        if upgrade_request is not None:
            self.connection.initiate_upgrade_connection(
                upgrade_request.headers, upgrade_request.target)