Example #1
0
    async def start_server(self, sock: anyio.abc.SocketStream, filter=None):  # pylint: disable=W0622
        """Start a server WS connection on this socket.

        Reads the client's opening handshake, optionally runs it through
        *filter*, and answers with an AcceptConnection or RejectConnection.

        Filter: an async callable that gets passed the initial Request.
        It may return an AcceptConnection message, a bool, or a string (the
        subprotocol to use). A falsy result rejects the connection; ``True``
        (or no filter at all) accepts it with the client's first offered
        subprotocol, or with no subprotocol if the client offered none.

        Returns: the Request message.

        Raises:
            ConnectionError: if the first event is not a Request, or if the
                connection was not accepted.
        """
        assert self._scope is None
        self._scope = True
        self._sock = sock
        self._connection = WSConnection(ConnectionType.SERVER)

        try:
            event = await self._next_event()
            if not isinstance(event, Request):
                raise ConnectionError("Failed to establish a connection",
                                      event)
            msg = None
            if filter is not None:
                msg = await filter(event)
                if not msg:
                    msg = RejectConnection()
                elif msg is True:
                    msg = None
                elif isinstance(msg, str):
                    msg = AcceptConnection(subprotocol=msg)
            if not msg:
                # Default choice: the client's first offered subprotocol.
                # Indexing an empty offer list would raise IndexError.
                subprotocol = event.subprotocols[0] if event.subprotocols else None
                msg = AcceptConnection(subprotocol=subprotocol)
            data = self._connection.send(msg)
            await self._sock.send_all(data)
            if not isinstance(msg, AcceptConnection):
                raise ConnectionError("Not accepted", msg)
            # The docstring has always promised the Request; actually return it.
            return event
        finally:
            self._scope = None
Example #2
0
    def __init__(self, ctx, handshake_flow):
        """Prepare paired wsproto connections for an already-completed handshake.

        ``self.connections`` maps ``self.client_conn`` to a server-role
        WSConnection and ``self.server_conn`` to a client-role WSConnection.
        The opening handshake is then replayed between those two in memory,
        asserting that each side observes the expected event. When the real
        response's Sec-WebSocket-Extensions header mentions
        permessage-deflate, both sides get a finalized PerMessageDeflate.
        """
        super().__init__(ctx)
        self.handshake_flow = handshake_flow
        self.flow: WebSocketFlow = None

        self.client_frame_buffer = []
        self.server_frame_buffer = []

        self.connections: dict[object, WSConnection] = {}

        response_headers = handshake_flow.response.headers
        use_deflate = (
            'Sec-WebSocket-Extensions' in response_headers
            and PerMessageDeflate.name in response_headers['Sec-WebSocket-Extensions']
        )
        client_extensions = [PerMessageDeflate()] if use_deflate else []
        server_extensions = [PerMessageDeflate()] if use_deflate else []

        # Towards the client we play the server role, and vice versa.
        self.connections[self.client_conn] = WSConnection(ConnectionType.SERVER)
        self.connections[self.server_conn] = WSConnection(ConnectionType.CLIENT)
        client_facing = self.connections[self.client_conn]
        server_facing = self.connections[self.server_conn]

        for extension in client_extensions + server_extensions:
            extension.finalize(response_headers['Sec-WebSocket-Extensions'])

        # Replay the opening handshake between the two in-memory endpoints.
        request = Request(extensions=client_extensions, host=handshake_flow.request.host, target=handshake_flow.request.path)
        client_facing.receive_data(server_facing.send(request))

        # Kept as a separate statement so the event is consumed even under -O.
        opening = next(client_facing.events())
        assert isinstance(opening, events.Request)

        accept = AcceptConnection(extensions=server_extensions)
        server_facing.receive_data(client_facing.send(accept))
        assert isinstance(next(server_facing.events()), events.AcceptConnection)
Example #3
0
def test_handshake_extra_accept_headers() -> None:
    """An unrecognised header in the 101 response shows up in extra_headers."""
    response_headers = [
        (b"connection", b"Upgrade"),
        (b"upgrade", b"WebSocket"),
        (b"X-Foo", b"bar"),
    ]
    received = _make_handshake(101, response_headers)
    assert received == [AcceptConnection(extra_headers=[(b"x-foo", b"bar")])]
Example #4
0
 def _handle_events(self):
     """Drain pending wsproto events and react to each one.

     A handshake Request is accepted, pings are answered, and a
     CloseConnection is acknowledged (only when ``self.is_server``). Text
     and binary payloads are appended to ``self.input_buffer``.
     ``self.event`` is set on close and on every received message. Any
     bytes wsproto produces in reply are written to ``self.sock`` in a
     single ``send`` call.

     Returns:
         True to keep servicing the connection. False once a
         CloseConnection was seen, or if processing raised an exception.
     """
     keep_going = True
     out_data = b''
     try:
         for event in self.ws.events():
             if isinstance(event, Request):
                 out_data += self.ws.send(AcceptConnection())
             elif isinstance(event, CloseConnection):
                 if self.is_server:
                     out_data += self.ws.send(event.response())
                 self.event.set()
                 keep_going = False
             elif isinstance(event, Ping):
                 out_data += self.ws.send(event.response())
             elif isinstance(event, (TextMessage, BytesMessage)):
                 self.input_buffer.append(event.data)
                 self.event.set()
         if out_data:
             self.sock.send(out_data)
         return keep_going
     except Exception:
         # Treat protocol/socket failures as end-of-connection, but let
         # KeyboardInterrupt and SystemExit propagate instead of being
         # swallowed by a bare ``except:``.
         return False
Example #5
0
 async def asgi_send(self, message: dict) -> None:
     """Called by the ASGI instance to send a message.

     Dispatches on ``message["type"]`` and the current ``self.state``:

     * ``websocket.accept`` during the handshake sends an AcceptConnection
       offering permessage-deflate, moves to CONNECTED, and logs a 101.
     * ``websocket.http.response.start`` during the handshake stores the
       message in ``self.response`` and logs it.
     * ``websocket.http.response.body`` during HANDSHAKE/RESPONSE is passed
       to ``self._asgi_send_rejection``.
     * ``websocket.send`` while CONNECTED sends a bytes or text frame.
     * ``websocket.close`` during the handshake sends an HTTP 403. Otherwise
       it sends a close frame; per the ASGI spec "code" is optional
       (default 1000) and an optional "reason" is forwarded.

     Raises:
         TypeError: if a ``websocket.send`` message has no bytes and its
             ``text`` is not a str (including when ``text`` is absent).
         UnexpectedMessage: for any message not valid in the current state.
     """
     if message["type"] == "websocket.accept" and self.state == ASGIWebsocketState.HANDSHAKE:
         await self.asend(AcceptConnection(extensions=[PerMessageDeflate()]))
         self.state = ASGIWebsocketState.CONNECTED
         self.config.access_logger.access(
             self.scope, {"status": 101, "headers": []}, time() - self.start_time
         )
     elif (
         message["type"] == "websocket.http.response.start"
         and self.state == ASGIWebsocketState.HANDSHAKE
     ):
         self.response = message
         self.config.access_logger.access(self.scope, self.response, time() - self.start_time)
     elif message["type"] == "websocket.http.response.body" and self.state in {
         ASGIWebsocketState.HANDSHAKE,
         ASGIWebsocketState.RESPONSE,
     }:
         await self._asgi_send_rejection(message)
     elif message["type"] == "websocket.send" and self.state == ASGIWebsocketState.CONNECTED:
         if message.get("bytes") is not None:
             await self.asend(BytesMessage(data=bytes(message["bytes"])))
         else:
             # .get() so a message with neither payload raises the
             # documented TypeError rather than a KeyError.
             text = message.get("text")
             if not isinstance(text, str):
                 raise TypeError(f"{text} should be a str")
             await self.asend(TextMessage(data=text))
     elif message["type"] == "websocket.close" and self.state == ASGIWebsocketState.HANDSHAKE:
         # Closing before accepting rejects the handshake with HTTP 403.
         await self.send_http_error(403)
         self.state = ASGIWebsocketState.HTTPCLOSED
     elif message["type"] == "websocket.close":
         # ASGI: "code" is optional and defaults to 1000; "reason" is optional.
         await self.asend(
             CloseConnection(
                 code=int(message.get("code", 1000)),
                 reason=message.get("reason"),
             )
         )
         self.state = ASGIWebsocketState.CLOSED
     else:
         raise UnexpectedMessage(self.state, message["type"])
Example #6
0
def _make_handshake(
    request_headers: Headers,
    accept_headers: Optional[Headers] = None,
    subprotocol: Optional[str] = None,
    extensions: Optional[List[Extension]] = None,
) -> Tuple[h11.InformationalResponse, bytes]:
    """Drive a server-side WebSocket handshake from a raw h11 client.

    The h11 client sends a GET upgrade request with the mandatory WebSocket
    headers followed by *request_headers*. The wsproto server answers with
    an AcceptConnection built from *accept_headers*, *subprotocol* and
    *extensions*.

    Returns:
        ``(event, nonce)``: the h11 client's next event after reading the
        server's reply, and the Sec-WebSocket-Key nonce that was sent.
    """
    client = h11.Connection(h11.CLIENT)
    server = WSConnection(SERVER)
    nonce = generate_nonce()

    mandatory_headers = [
        (b"Host", b"localhost"),
        (b"Connection", b"Keep-Alive, Upgrade"),
        (b"Upgrade", b"WebSocket"),
        (b"Sec-WebSocket-Version", b"13"),
        (b"Sec-WebSocket-Key", nonce),
    ]
    upgrade_request = h11.Request(
        method="GET",
        target="/",
        headers=mandatory_headers + request_headers,
    )
    server.receive_data(client.send(upgrade_request))

    reply = AcceptConnection(
        extra_headers=accept_headers or [],
        subprotocol=subprotocol,
        extensions=extensions or [],
    )
    client.receive_data(server.send(reply))
    return client.next_event(), nonce
Example #7
0
def websocket(request):
    """Serve a WebSocket echo endpoint on top of a WSGI request.

    The raw client socket is taken from the WSGI environ, and the upgrade
    request the WSGI server already parsed is replayed into a server-side
    wsproto connection. A blocking read/handle/write loop then echoes text
    messages until the client closes, sends "quit", or drops the TCP
    connection.

    Raises:
        InternalServerError: if the server does not expose the socket.
    """
    # The underlying socket must be provided by the server. Gunicorn and
    # Werkzeug's dev server are known to support this.
    stream = request.environ.get("werkzeug.socket")

    if stream is None:
        stream = request.environ.get("gunicorn.socket")

    if stream is None:
        raise InternalServerError()

    # Initialize the wsproto connection. Need to recreate the request
    # data that was read by the WSGI server already.
    ws = WSConnection(ConnectionType.SERVER)
    request_parts = [b"GET %s HTTP/1.1\r\n" % request.path.encode("utf8")]
    request_parts.extend(
        f"{header}: {value}\r\n".encode() for header, value in request.headers.items()
    )
    request_parts.append(b"\r\n")
    ws.receive_data(b"".join(request_parts))
    running = True

    while True:
        out_data = b""

        for event in ws.events():
            if isinstance(event, WSRequest):
                out_data += ws.send(AcceptConnection())
            elif isinstance(event, CloseConnection):
                out_data += ws.send(event.response())
                running = False
            elif isinstance(event, Ping):
                out_data += ws.send(event.response())
            elif isinstance(event, TextMessage):
                # echo the incoming message back to the client
                if event.data == "quit":
                    out_data += ws.send(
                        CloseConnection(CloseReason.NORMAL_CLOSURE, "bye")
                    )
                    running = False
                else:
                    out_data += ws.send(Message(data=event.data))

        if out_data:
            # sendall: a plain send() may write only part of the buffer.
            stream.sendall(out_data)

        if not running:
            break

        in_data = stream.recv(4096)
        if not in_data:
            # recv() returns b"" once the peer has closed the TCP connection.
            # wsproto yields no events for empty input, so without this check
            # the loop would spin forever.
            break
        ws.receive_data(in_data)

    # The connection will be closed at this point, but WSGI still
    # requires a response.
    return Response("", status=204)
Example #8
0
def test_handshake_with_subprotocol():
    """The subprotocol chosen in the 101 response is reported on AcceptConnection."""
    response_headers = [
        (b"connection", b"Upgrade"),
        (b"upgrade", b"WebSocket"),
        (b"sec-websocket-protocol", b"one"),
    ]
    received = _make_handshake(101, response_headers, subprotocols=["one", "two"])
    assert received == [AcceptConnection(subprotocol="one")]
Example #9
0
    def accept_request(self,
                       extra_headers: Headers = None,
                       sub_protocol: str = None) -> None:
        """Complete the WebSocket handshake by sending an AcceptConnection.

        Args:
            extra_headers: additional response headers. They are first passed
                to ``self._check_ws_headers``. A falsy value sends none.
            sub_protocol: subprotocol to report to the client, or ``None``.

        Raises:
            TypeError: if *sub_protocol* is given but is not a ``str``.
        """
        self._check_ws_headers(extra_headers)
        if not (sub_protocol is None or isinstance(sub_protocol, str)):
            raise TypeError('sub_protocol must be a string')

        accept = AcceptConnection(extra_headers=extra_headers or [],
                                  subprotocol=sub_protocol)
        self._client.sendall(self._ws.send(accept))
Example #10
0
def test_handshake_with_extension():
    """An extension named in the 101 response is reported on AcceptConnection."""
    fake = FakeExtension(offer_response=True)
    response_headers = [
        (b"connection", b"Upgrade"),
        (b"upgrade", b"WebSocket"),
        (b"sec-websocket-extensions", b"fake"),
    ]
    received = _make_handshake(101, response_headers, extensions=[fake])
    assert received == [AcceptConnection(extensions=[fake])]
Example #11
0
def test_successful_handshake():
    """A client/server H11Handshake pair both reach OPEN after request + accept."""
    client = H11Handshake(CLIENT)
    server = H11Handshake(SERVER)

    request_bytes = client.send(Request(host="localhost", target="/"))
    server.receive_data(request_bytes)
    assert isinstance(next(server.events()), Request)

    accept_bytes = server.send(AcceptConnection())
    client.receive_data(accept_bytes)
    assert isinstance(next(client.events()), AcceptConnection)

    for side in (client, server):
        assert side.state is ConnectionState.OPEN
Example #12
0
def handle_connection(stream: socket.socket) -> None:
    """
    Handle a connection.

    The server operates a request/response cycle, so it performs a synchronous
    loop:

    1) Read data from network into wsproto
    2) Get new events and handle them
    3) Send data from wsproto to network

    The loop ends when the client closes the WebSocket or drops the TCP
    connection.

    :param stream: a socket stream
    """
    ws = WSConnection(ConnectionType.SERVER)
    running = True

    while running:
        # 1) Read data from network
        in_data = stream.recv(RECEIVE_BYTES)
        print("Received {} bytes".format(len(in_data)))
        if not in_data:
            # recv() returns b"" once the peer has closed the TCP connection.
            # wsproto yields no events for empty input, so without this check
            # the loop would spin forever.
            print("Client connection dropped unexpectedly")
            return
        ws.receive_data(in_data)

        # 2) Get new events and handle them
        out_data = b""
        for event in ws.events():
            if isinstance(event, Request):
                # Negotiate new WebSocket connection
                print("Accepting WebSocket upgrade")
                out_data += ws.send(AcceptConnection())
            elif isinstance(event, CloseConnection):
                # Print log message and break out
                print("Connection closed: code={} reason={}".format(
                    event.code, event.reason))
                out_data += ws.send(event.response())
                running = False
            elif isinstance(event, TextMessage):
                # Reverse text and send it back to wsproto
                print("Received request and sending response")
                out_data += ws.send(Message(data=event.data[::-1]))
            elif isinstance(event, Ping):
                # Reply with the matching Pong built by event.response(); it is
                # only sent because we queue it here ourselves.
                print("Received ping and sending pong")
                out_data += ws.send(event.response())
            else:
                print(f"Unknown event: {event!r}")

        # 3) Send data from wsproto to network
        if out_data:
            print("Sending {} bytes".format(len(out_data)))
            # sendall: a plain send() may write only part of the buffer.
            stream.sendall(out_data)
Example #13
0
    async def init_for_server(cls, stream):
        """Build a server-side transport and complete the WebSocket handshake.

        Waits up to ``WEBSOCKET_HANDSHAKE_TIMEOUT`` for the first wsproto
        event. If it is a Request, the upgrade is accepted and the transport
        is returned.

        Raises:
            TransportError: if the timeout expires first or the first event
                is not a Request.
        """
        transport = cls(stream, WSConnection(ConnectionType.SERVER))

        # Remains this string if the timeout fires before any event arrives.
        first_event = "Websocket handshake timeout"
        with trio.move_on_after(WEBSOCKET_HANDSHAKE_TIMEOUT):
            first_event = await transport._next_ws_event()

        if not isinstance(first_event, Request):
            transport.logger.warning("Unexpected event during WebSocket handshake", ws_event=first_event)
            raise TransportError(f"Unexpected event during WebSocket handshake: {first_event}")

        transport.logger.debug("Accepting WebSocket upgrade")
        await transport._net_send(AcceptConnection())
        return transport
Example #14
0
    async def upgrade(self, request):
        """Replay *request* into the wsproto connection and accept the upgrade.

        The HTTP request line and headers are re-serialized so that
        ``self.protocol`` can parse the opening handshake. An
        AcceptConnection response is then written to ``self.socket``.

        Raises:
            HTTPError: ``400 Bad Request`` if wsproto rejects the request or
                does not produce a Request event from it.
        """
        head = '{} {} HTTP/1.1\r\n'.format(request.method, request.url)
        # Each header line carries its own CRLF, so a request without headers
        # still ends with exactly one blank line.
        head += ''.join('{}: {}\r\n'.format(k, v)
                        for k, v in request.headers.items())
        head += '\r\n'

        try:
            self.protocol.receive_data(head.encode())
        except RemoteProtocolError as exc:
            raise HTTPError(HTTPStatus.BAD_REQUEST) from exc

        # A bare StopIteration escaping a coroutine is turned into a
        # RuntimeError, so use a default and report a clean 400 instead.
        event = next(self.protocol.events(), None)
        if not isinstance(event, Request):
            raise HTTPError(HTTPStatus.BAD_REQUEST)
        data = self.protocol.send(AcceptConnection())
        await self.socket.sendall(data)
Example #15
0
    async def init_for_server(  # type: ignore[misc]
        cls: Type["Transport"], stream: Stream, upgrade_request: Optional[H11Request] = None
    ) -> "Transport":
        """Build a server-side transport and complete the WebSocket handshake.

        If *upgrade_request* is given, its headers and target are handed to
        wsproto via ``initiate_upgrade_connection`` first. The method then
        waits up to ``WEBSOCKET_HANDSHAKE_TIMEOUT`` for the first wsproto
        event, accepts it if it is a Request, and returns the transport.

        Raises:
            TransportError: if the timeout expires first or the first event
                is not a Request.
        """
        connection = WSConnection(ConnectionType.SERVER)
        if upgrade_request:
            connection.initiate_upgrade_connection(
                headers=upgrade_request.headers, path=upgrade_request.target
            )
        transport = cls(stream, connection)

        # Remains this string if the timeout fires before any event arrives.
        handshake_event: Union[str, Event] = "Websocket handshake timeout"
        with trio.move_on_after(WEBSOCKET_HANDSHAKE_TIMEOUT):
            handshake_event = await transport._next_ws_event()

        if not isinstance(handshake_event, Request):
            transport.logger.warning("Unexpected event during WebSocket handshake", ws_event=handshake_event)
            raise TransportError(f"Unexpected event during WebSocket handshake: {handshake_event}")

        transport.logger.debug("Accepting WebSocket upgrade")
        await transport._net_send(AcceptConnection())
        return transport
Example #16
0
def new_conn(sock: socket.socket) -> None:
    """Serve a single client connection of the echo test server.

    Accepts the WebSocket upgrade (offering permessage-deflate), echoes every
    message back with the same ``message_finished`` flag, and answers pings.
    It replies to a close frame unless the connection is already CLOSED.
    The module-level ``count`` is incremented for logging.

    *sock* is closed when the loop ends, and also if processing raises, so
    a misbehaving peer cannot leak the socket.
    """
    global count
    print("test_server.py received connection {}".format(count))
    count += 1
    ws = WSConnection(SERVER)
    closed = False
    try:
        while not closed:
            try:
                data: Optional[bytes] = sock.recv(65535)
            except socket.error:
                data = None

            # Both b"" (orderly EOF) and None (socket error) are passed on as
            # None, wsproto's signal that the transport has gone away.
            ws.receive_data(data or None)

            outgoing_data = b""
            for event in ws.events():
                if isinstance(event, Request):
                    outgoing_data += ws.send(
                        AcceptConnection(extensions=[PerMessageDeflate()]))
                elif isinstance(event, Message):
                    outgoing_data += ws.send(
                        Message(data=event.data,
                                message_finished=event.message_finished))
                elif isinstance(event, Ping):
                    outgoing_data += ws.send(event.response())
                elif isinstance(event, CloseConnection):
                    closed = True
                    # Only reply while wsproto can still send; it may already
                    # be CLOSED (e.g. after receive_data(None)).
                    if ws.state is not ConnectionState.CLOSED:
                        outgoing_data += ws.send(event.response())

            if not data:
                closed = True

            try:
                sock.sendall(outgoing_data)
            except socket.error:
                closed = True
    finally:
        sock.close()
Example #17
0
    async def init_for_server(cls, stream, first_request_data=None):
        """Build a server-side transport and complete the WebSocket handshake.

        *first_request_data*, if given, is fed to wsproto before the
        transport is created. The method then waits up to
        ``WEBSOCKET_HANDSHAKE_TIMEOUT`` for the first wsproto event, accepts
        it if it is a Request, and returns the transport.

        Raises:
            TransportError: if wsproto raises RemoteProtocolError while
                parsing the handshake, if the timeout expires first, or if
                the first event is not a Request.
        """
        connection = WSConnection(ConnectionType.SERVER)
        try:
            if first_request_data:
                connection.receive_data(first_request_data)
            transport = cls(stream, connection)

            # Remains this string if the timeout fires before any event arrives.
            handshake_event = "Websocket handshake timeout"
            with trio.move_on_after(WEBSOCKET_HANDSHAKE_TIMEOUT):
                handshake_event = await transport._next_ws_event()
        except RemoteProtocolError as exc:
            raise TransportError(f"Invalid WebSocket query: {exc}") from exc

        if not isinstance(handshake_event, Request):
            transport.logger.warning("Unexpected event during WebSocket handshake",
                                     ws_event=handshake_event)
            raise TransportError(
                f"Unexpected event during WebSocket handshake: {handshake_event}")

        transport.logger.debug("Accepting WebSocket upgrade")
        await transport._net_send(AcceptConnection())
        return transport
Example #18
0
def handle_connection(stream):
    '''
    Handle a connection.

    The server operates a request/response cycle, so it performs a synchronous
    loop:

    1) Read data from network into wsproto
    2) Get the wsproto events produced by that data
    3) Handle each event
    4) Send data from wsproto to network

    The loop ends when the client closes the WebSocket or drops the TCP
    connection.

    :param stream: a socket stream
    '''
    ws = WSConnection(ConnectionType.SERVER)
    running = True

    while running:
        # 1) Read data from network
        in_data = stream.recv(RECEIVE_BYTES)
        print('Received {} bytes'.format(len(in_data)))
        if not in_data:
            # recv() returns b'' once the peer has closed the TCP connection.
            print('Client connection dropped unexpectedly')
            return
        ws.receive_data(in_data)

        # 2) + 3) Handle every event available so far. A fresh ws.events()
        # generator is used on each pass: a single long-lived generator
        # raises StopIteration forever once it has run out of events, and
        # handling only one event per read would leave further events queued
        # while we block in recv().
        out_data = b''
        for event in ws.events():
            if isinstance(event, Request):
                # Negotiate new WebSocket connection
                print('Accepting WebSocket upgrade')
                out_data += ws.send(AcceptConnection())
            elif isinstance(event, CloseConnection):
                # event.code may be a plain int (no .value/.name) for codes
                # that have no CloseReason member.
                print('Connection closed: code={}/{} reason={}'.format(
                    int(event.code), getattr(event.code, 'name', 'UNKNOWN'),
                    event.reason))
                out_data += ws.send(event.response())
                running = False
            elif isinstance(event, TextMessage):
                # Reverse text and send it back to wsproto
                print('Received request and sending response')
                out_data += ws.send(Message(data=event.data[::-1]))
            elif isinstance(event, Ping):
                # Reply with the matching Pong built by event.response(); it
                # is only sent because we queue it here ourselves.
                print('Received ping and sending pong')
                out_data += ws.send(event.response())
            else:
                # Nothing to send for other events (previously a stale
                # out_data from an earlier pass could be re-sent here).
                print('Unknown event: {!r}'.format(event))

        # 4) Send data from wsproto to network
        if out_data:
            print('Sending {} bytes'.format(len(out_data)))
            # sendall: a plain send() may write only part of the buffer.
            stream.sendall(out_data)
Example #19
0
        async def sender(send: Send, port: QueuePort):
            """Pump ASGI ``websocket.*`` messages from *port* out through *send*.

            Loops until ``websocket`` (a wsproto connection from the enclosing
            scope) is CLOSED or LOCAL_CLOSING, or until ``port.pull()`` returns
            ``QueuePort.PORT_CLOSED_SENTINEL``. WebSocket frames produced by
            ``websocket.send(...)`` are forwarded as ``http.response.body``
            chunks. Denial responses are forwarded as plain
            ``http.response.*`` messages.
            """
            # Handshake progress as tracked by this sender:
            # 'connecting' -> 'connected' (after websocket.accept), or
            # 'denied' (after close-before-accept or an explicit HTTP response).
            state = 'connecting'

            while websocket.state not in (ConnectionState.CLOSED,
                                          ConnectionState.LOCAL_CLOSING):
                event = await port.pull()
                if event is QueuePort.PORT_CLOSED_SENTINEL: return

                assert isinstance(event, dict)

                if event["type"] == "websocket.send":
                    # NOTE(review): `"bytes" in event` only tests key presence.
                    # ASGI allows both "bytes" and "text" keys with one set to
                    # None, which would send BytesMessage(data=None) here;
                    # consider `event.get("bytes") is not None`. A message with
                    # neither key is silently dropped.
                    if "bytes" in event:
                        await send({
                            'type':
                            'http.response.body',
                            'body':
                            websocket.send(BytesMessage(data=event["bytes"])),
                            'more_body':
                            True
                        })
                    elif "text" in event:
                        await send({
                            'type':
                            'http.response.body',
                            'body':
                            websocket.send(TextMessage(data=event["text"])),
                            'more_body':
                            True
                        })

                elif event["type"] == "websocket.close":
                    if state == "connected":
                        # Send a close frame (code defaults to 1000) as the
                        # final body chunk, then close the port with that code.
                        code = event.get("code", 1000)
                        await send({
                            'type':
                            'http.response.body',
                            'body':
                            websocket.send(CloseConnection(code=code)),
                            'more_body':
                            False
                        })
                        await port.close({
                            'type': 'websocket.close',
                            'code': code
                        })
                    else:
                        # Closing before accept: reject the upgrade with an
                        # empty HTTP 403 response.
                        state = "denied"
                        await send({
                            'type': 'http.response.start',
                            'status': 403,
                            'headers': []
                        })
                        await send({
                            'type': 'http.response.body',
                            'body': b'',
                            'more_body': False
                        })
                        await port.close({'type': 'websocket.close'})

                elif event["type"] in ("websocket.http.response.start",
                                       "websocket.http.response.body"):
                    # Denial response: only valid before accept, and a body
                    # must be preceded by a start (and only one start).
                    if state == "connected":
                        raise ValueError(
                            "You already accepted a websocket connection.")
                    if state == "connecting" and event[
                            "type"] == "websocket.http.response.body":
                        raise ValueError("You did not start a response.")
                    elif state == "denied" and event[
                            "type"] == "websocket.http.response.start":
                        raise ValueError("You already started a response.")

                    state = "denied"
                    event = event.copy()
                    # Strip the "websocket." prefix to get the plain HTTP type.
                    event["type"] = event["type"][len("websocket."):]
                    await send(event)

                elif event["type"] == "websocket.accept":
                    # Let wsproto build the 101 response bytes, then parse them
                    # back through `connection` (presumably an h11 client-side
                    # Connection from the enclosing scope; confirm) to get
                    # the status code and headers for http.response.start.
                    # NOTE(review): `state` is not checked here, so an accept
                    # after a denial is not rejected.
                    raw = websocket.send(
                        AcceptConnection(
                            extra_headers=event.get("headers", []),
                            subprotocol=event.get("subprotocol", None)))
                    connection.receive_data(raw)
                    response = connection.next_event()
                    state = "connected"
                    await send({
                        'type': 'http.response.start',
                        'status': response.status_code,
                        'headers': response.headers
                    })
Example #20
0
def test_handshake():
    """A plain 101 Switching Protocols response yields a bare AcceptConnection."""
    upgrade_headers = [(b"connection", b"Upgrade"), (b"upgrade", b"WebSocket")]
    received = _make_handshake(101, upgrade_headers)
    assert received == [AcceptConnection()]