def _make_handshake(
    response_status,
    response_headers,
    subprotocols=None,
    extensions=None,
    auto_accept_key=True,
):
    """Drive a client handshake against a scripted server response.

    A ``WSConnection(CLIENT)`` sends an opening ``Request`` for
    ``localhost`` / ``/`` to an ``h11`` server connection, the server replies
    with an ``h11.InformationalResponse`` built from the given status and
    headers, and the events the client produces are returned.

    Args:
        response_status: Status code placed on the server's response.
        response_headers: Iterable of ``(name, value)`` header pairs for the
            response. It is copied, never modified.
        subprotocols: Subprotocols offered in the request; a falsy value is
            sent as an empty list.
        extensions: Extensions offered in the request; a falsy value is sent
            as an empty list.
        auto_accept_key: When true, a ``Sec-WebSocket-Accept`` header derived
            from the request's ``sec-websocket-key`` is added to the response.

    Returns:
        A list of the events yielded by ``client.events()`` after the
        response has been fed to the client.
    """
    client = WSConnection(CLIENT)
    server = h11.Connection(h11.SERVER)
    server.receive_data(
        client.send(
            Request(
                host="localhost",
                target="/",
                subprotocols=subprotocols or [],
                extensions=extensions or [],
            )
        )
    )
    request = server.next_event()

    # Work on a copy: appending to the caller's list would leak the accept
    # header into the caller's data and duplicate it if the list is reused.
    headers = list(response_headers)
    if auto_accept_key:
        request_headers = normed_header_dict(request.headers)
        accept_token = generate_accept_token(request_headers[b"sec-websocket-key"])
        headers.append((b"Sec-WebSocket-Accept", accept_token))

    response = h11.InformationalResponse(
        status_code=response_status, headers=headers
    )
    client.receive_data(server.send(response))

    return list(client.events())
Exemple #2
0
    def accept(
        self, subprotocol: Optional[str]
    ) -> Tuple[int, List[Tuple[bytes, bytes]], Connection]:
        """Assemble the response used to accept the WebSocket handshake.

        Args:
            subprotocol: Subprotocol to confirm, or ``None`` to confirm none.
                It must be one of ``self.subprotocols``.

        Returns:
            A ``(status_code, headers, connection)`` tuple. The status is 101
            with ``upgrade``/``connection`` headers when ``self.http_version``
            is ``"1.1"`` and 200 otherwise; ``connection`` is a server-side
            ``Connection`` created with a ``PerMessageDeflate`` extension.

        Raises:
            Exception: If ``subprotocol`` is given but not in
                ``self.subprotocols``.
        """
        if subprotocol is not None and subprotocol not in self.subprotocols:
            raise Exception("Invalid Subprotocol")

        response_headers = []
        if subprotocol is not None:
            response_headers.append(
                (b"sec-websocket-protocol", subprotocol.encode())
            )

        offered = [PerMessageDeflate()]
        negotiated = None
        # NOTE(review): the literal ``False`` short-circuits this condition, so
        # the extension handshake never runs and ``negotiated`` stays ``None``.
        # This looks like a deliberate switch-off - confirm before re-enabling.
        if False and self.extensions is not None:
            negotiated = server_extensions_handshake(self.extensions, offered)
        if negotiated:
            response_headers.append((b"sec-websocket-extensions", negotiated))

        if self.key is not None:
            response_headers.append(
                (b"sec-websocket-accept", generate_accept_token(self.key))
            )

        if self.http_version == "1.1":
            status_code = 101
            response_headers += [
                (b"upgrade", b"WebSocket"),
                (b"connection", b"Upgrade"),
            ]
        else:
            status_code = 200

        connection = Connection(ConnectionType.SERVER, offered)
        return status_code, response_headers, connection
Exemple #3
0
def test_connection_send_state() -> None:
    """Check client state transitions and send restrictions around the handshake.

    The client starts in ``CONNECTING``, moves to ``OPEN`` once a 101 response
    with a matching accept token is received, refuses to send another
    ``Request`` afterwards, and yields one event for the bytes fed in last.
    """
    client = WSConnection(CLIENT)
    assert client.state is ConnectionState.CONNECTING

    server = h11.Connection(h11.SERVER)
    opening_request = client.send(Request(host="localhost", target="/"))
    server.receive_data(opening_request)

    request_headers = normed_header_dict(server.next_event().headers)
    accept_token = generate_accept_token(request_headers[b"sec-websocket-key"])
    upgrade_response = h11.InformationalResponse(
        status_code=101,
        headers=[
            (b"connection", b"Upgrade"),
            (b"upgrade", b"WebSocket"),
            (b"Sec-WebSocket-Accept", accept_token),
        ],
    )
    client.receive_data(server.send(upgrade_response))

    handshake_events = list(client.events())
    assert len(handshake_events) == 1
    assert client.state is ConnectionState.OPEN  # type: ignore # https://github.com/python/mypy/issues/9005

    # With the connection open, sending a second opening request must fail.
    with pytest.raises(LocalProtocolError):
        client.send(Request(host="localhost", target="/"))

    # Bytes received after the handshake must surface as exactly one event.
    client.receive_data(b"foobar")
    trailing_events = list(client.events())
    assert len(trailing_events) == 1
Exemple #4
0
def test_handshake() -> None:
    """Check that the handshake helper produces the expected 101 response.

    The response must carry ``connection``, ``upgrade`` and a
    ``sec-websocket-accept`` token derived from the nonce the helper returns.
    """
    # NOTE(review): this expects a ``_make_handshake`` that takes a single
    # argument and returns ``(response, nonce)`` - confirm which helper is in
    # scope where this test lives.
    response, nonce = _make_handshake([])

    # Sort the headers so the comparison does not depend on their order.
    response.headers = sorted(response.headers)
    expected = h11.InformationalResponse(
        status_code=101,
        headers=[
            (b"connection", b"Upgrade"),
            (b"sec-websocket-accept", generate_accept_token(nonce)),
            (b"upgrade", b"WebSocket"),
        ],
    )
    assert response == expected
Exemple #5
0
    def accept(
        self,
        subprotocol: Optional[str],
        additional_headers: Iterable[Tuple[bytes, bytes]],
    ) -> Tuple[int, List[Tuple[bytes, bytes]], Connection]:
        """Assemble the response used to accept the WebSocket handshake.

        Args:
            subprotocol: Subprotocol to confirm, or ``None`` to confirm none.
                It must be one of ``self.subprotocols``.
            additional_headers: Extra ``(name, value)`` pairs appended after
                the handshake headers. ``sec-websocket-protocol`` (in any
                letter case) and names beginning with ``b":"`` are rejected.

        Returns:
            A ``(status_code, headers, connection)`` tuple. The status is 101
            with ``upgrade``/``connection`` headers when ``self.http_version``
            is ``"1.1"`` and 200 otherwise; ``connection`` is a server-side
            ``Connection`` created with the ``PerMessageDeflate`` extension
            that took part in the extension handshake.

        Raises:
            Exception: If ``subprotocol`` is given but not in
                ``self.subprotocols``, or if ``additional_headers`` contains a
                rejected header name.
        """
        headers: List[Tuple[bytes, bytes]] = []
        if subprotocol is not None:
            if subprotocol not in self.subprotocols:
                raise Exception("Invalid Subprotocol")
            headers.append((b"sec-websocket-protocol", subprotocol.encode()))

        extensions: List[Extension] = [PerMessageDeflate()]
        accepts = None
        if self.extensions is not None:
            accepts = server_extensions_handshake(self.extensions, extensions)
        if accepts:
            headers.append((b"sec-websocket-extensions", accepts))

        if self.key is not None:
            headers.append(
                (b"sec-websocket-accept", generate_accept_token(self.key))
            )

        status_code = 200
        if self.http_version == "1.1":
            headers.extend(
                [(b"upgrade", b"WebSocket"), (b"connection", b"Upgrade")]
            )
            status_code = 101

        for name, value in additional_headers:
            # Header names are case-insensitive, so compare the lowered name;
            # otherwise b"Sec-WebSocket-Protocol" would slip past this check
            # and duplicate or override the negotiated subprotocol header.
            if name.lower() == b"sec-websocket-protocol" or name.startswith(b":"):
                raise Exception(f"Invalid additional header, {name.decode()}")

            headers.append((name, value))

        return status_code, headers, Connection(ConnectionType.SERVER,
                                                extensions)