Пример #1
0
    def accept(
        self, subprotocol: Optional[str]
    ) -> Tuple[int, List[Tuple[bytes, bytes]], Connection]:
        """Build the server's handshake response for this WebSocket request.

        :param subprotocol: Subprotocol chosen by the application, or ``None``
            to select none. It must be one of ``self.subprotocols``.
        :return: ``(status_code, headers, connection)``. The status is ``101``
            together with ``upgrade``/``connection`` headers when
            ``self.http_version == "1.1"``, otherwise ``200``. The connection
            is a server-side ``Connection``.
        :raises Exception: If *subprotocol* is not in ``self.subprotocols``.
        """
        headers: List[Tuple[bytes, bytes]] = []
        if subprotocol is not None:
            if subprotocol not in self.subprotocols:
                raise Exception("Invalid Subprotocol")
            headers.append((b"sec-websocket-protocol", subprotocol.encode()))

        # Extension negotiation is deliberately disabled here: ``self.extensions``
        # is never consulted and no ``sec-websocket-extensions`` header is sent.
        # To re-enable it, call
        # ``server_extensions_handshake(self.extensions, extensions)`` when
        # ``self.extensions`` is not None and append its result as a
        # ``sec-websocket-extensions`` header when truthy.
        # NOTE(review): the PerMessageDeflate object is still handed to the
        # Connection without having been negotiated; presumably inert in that
        # state — confirm against wsproto before changing this.
        extensions = [PerMessageDeflate()]

        if self.key is not None:
            headers.append((b"sec-websocket-accept", generate_accept_token(self.key)))

        status_code = 200
        if self.http_version == "1.1":
            # HTTP/1.1 upgrades via 101 Switching Protocols.
            headers.extend([(b"upgrade", b"WebSocket"), (b"connection", b"Upgrade")])
            status_code = 101

        return status_code, headers, Connection(ConnectionType.SERVER, extensions)
Пример #2
0
    def accept(
        self,
        subprotocol: Optional[str],
        additional_headers: Iterable[Tuple[bytes, bytes]],
    ) -> Tuple[int, List[Tuple[bytes, bytes]], Connection]:
        """Build the server's handshake response for this WebSocket request.

        :param subprotocol: Subprotocol chosen by the application, or ``None``
            to select none. It must be one of ``self.subprotocols``.
        :param additional_headers: Extra ``(name, value)`` response headers
            from the application, appended after the handshake headers.
        :return: ``(status_code, headers, connection)``. The status is ``101``
            together with ``upgrade``/``connection`` headers when
            ``self.http_version == "1.1"``, otherwise ``200``. The connection
            is a server-side ``Connection`` given the permessage-deflate
            extension object.
        :raises Exception: If *subprotocol* is not in ``self.subprotocols``,
            or if an additional header is ``sec-websocket-protocol`` (in any
            letter case) or a pseudo-header (name starting with ``:``).
        """
        headers: List[Tuple[bytes, bytes]] = []
        if subprotocol is not None:
            if subprotocol not in self.subprotocols:
                raise Exception("Invalid Subprotocol")
            headers.append(
                (b"sec-websocket-protocol", subprotocol.encode()))

        # Negotiate permessage-deflate against ``self.extensions`` (when set);
        # the negotiated value is only sent when it is truthy.
        extensions: List[Extension] = [PerMessageDeflate()]
        accepts = None
        if self.extensions is not None:
            accepts = server_extensions_handshake(self.extensions, extensions)

        if accepts:
            headers.append((b"sec-websocket-extensions", accepts))

        if self.key is not None:
            headers.append(
                (b"sec-websocket-accept", generate_accept_token(self.key)))

        status_code = 200
        if self.http_version == "1.1":
            # HTTP/1.1 upgrades via 101 Switching Protocols.
            headers.extend([(b"upgrade", b"WebSocket"),
                            (b"connection", b"Upgrade")])
            status_code = 101

        for name, value in additional_headers:
            # Header names are case-insensitive (RFC 7230, section 3.2), so
            # compare lowercased; otherwise a mixed-case Sec-WebSocket-Protocol
            # would slip through and duplicate the subprotocol header above.
            is_subprotocol = name.lower() == b"sec-websocket-protocol"
            if is_subprotocol or name.startswith(b":"):
                # errors="replace" keeps a non-UTF-8 name from turning this
                # into a UnicodeDecodeError that hides the real problem.
                raise Exception(
                    f"Invalid additional header, {name.decode(errors='replace')}")

            headers.append((name, value))

        return status_code, headers, Connection(ConnectionType.SERVER,
                                                extensions)
Пример #3
0
 async def asgi_send(self, message: dict) -> None:
     """Act on one ASGI ``websocket.*`` message sent by the application.

     The action depends on ``message["type"]`` and ``self.state``:

     * ``websocket.accept`` in HANDSHAKE: switch to CONNECTED, negotiate
       permessage-deflate against every ``sec-websocket-extensions`` request
       header, send a ``:status`` 200 response, create the server-side
       wsproto connection and write an access-log entry.
     * ``websocket.http.response.start`` in HANDSHAKE: store the message as
       ``self.response`` and write an access-log entry.
     * ``websocket.http.response.body`` in HANDSHAKE or RESPONSE: delegate to
       ``_asgi_send_rejection``.
     * ``websocket.send`` in CONNECTED: send ``bytes`` if present, otherwise
       ``text``.
     * ``websocket.close`` in HANDSHAKE: send an HTTP 403 and switch to
       HTTPCLOSED.
     * ``websocket.close`` in any other state: send a close frame and switch
       to CLOSED. ``code`` is optional and defaults to 1000.

     :raises TypeError: If a ``websocket.send`` text payload is not a str.
     :raises UnexpectedMessage: For any other type/state combination.
     """
     message_type = message["type"]
     if (message_type == "websocket.accept"
             and self.state == ASGIWebsocketState.HANDSHAKE):
         self.state = ASGIWebsocketState.CONNECTED
         # A client may split its extension offers across several
         # sec-websocket-extensions header lines (RFC 6455, section 9.1), so
         # gather all of them instead of keeping only the last one.
         extensions: List[str] = []
         for name, value in self.scope["headers"]:
             if name == b"sec-websocket-extensions":
                 extensions.extend(split_comma_header(value))
         supported_extensions = [wsproto.extensions.PerMessageDeflate()]
         accepts = server_extensions_handshake(extensions,
                                               supported_extensions)
         headers = [(b":status", b"200")]
         headers.extend(
             build_and_validate_headers(message.get("headers", [])))
         # The subprotocol header is appended from message["subprotocol"]
         # below. NOTE(review): presumably this helper rejects an app-supplied
         # copy of that header — confirm.
         raise_if_subprotocol_present(headers)
         if message.get("subprotocol") is not None:
             headers.append((b"sec-websocket-protocol",
                             message["subprotocol"].encode()))
         if accepts:
             headers.append((b"sec-websocket-extensions", accepts))
         await self.asend(Response(headers))
         self.connection = wsproto.connection.Connection(
             wsproto.connection.ConnectionType.SERVER, supported_extensions)
         self.config.access_logger.access(self.scope, {
             "status": 200,
             "headers": []
         }, time() - self.start_time)
     elif (message_type == "websocket.http.response.start"
           and self.state == ASGIWebsocketState.HANDSHAKE):
         self.response = message
         self.config.access_logger.access(self.scope, self.response,
                                          time() - self.start_time)
     elif (message_type == "websocket.http.response.body"
           and self.state in {
               ASGIWebsocketState.HANDSHAKE,
               ASGIWebsocketState.RESPONSE,
           }):
         await self._asgi_send_rejection(message)
     elif (message_type == "websocket.send"
           and self.state == ASGIWebsocketState.CONNECTED):
         event: wsproto.events.Event
         if message.get("bytes") is not None:
             event = wsproto.events.BytesMessage(
                 data=bytes(message["bytes"]))
         elif not isinstance(message["text"], str):
             raise TypeError(f"{message['text']} should be a str")
         else:
             event = wsproto.events.TextMessage(data=message["text"])
         await self.asend(Data(self.connection.send(event)))
     elif (message_type == "websocket.close"
           and self.state == ASGIWebsocketState.HANDSHAKE):
         # Closing before accepting rejects the handshake with an HTTP 403.
         await self.send_http_error(403)
         self.state = ASGIWebsocketState.HTTPCLOSED
     elif message_type == "websocket.close":
         # ASGI makes "code" optional with a default of 1000 (normal
         # closure); .get avoids a KeyError when the app omits it.
         close_code = int(message.get("code", 1000))
         data = self.connection.send(
             wsproto.events.CloseConnection(code=close_code))
         await self.asend(Data(data))
         self.state = ASGIWebsocketState.CLOSED
     else:
         raise UnexpectedMessage(self.state, message_type)