async def start_server(self, sock: anyio.abc.SocketStream, filter=None):  # pylint: disable=W0622
    """Start a server WS connection on this socket.

    Filter: an async callable that gets passed the initial Request.
    It may return an AcceptConnection message, a bool, or a string (the
    subprotocol to use). A falsy result rejects the connection; ``True``
    (or having no filter at all) accepts it with the client's first
    offered subprotocol, or with no subprotocol if the client offered none.

    Returns: the Request message.

    Raises:
        ConnectionError: if the first event is not a Request, or if the
            connection was not accepted.
    """
    assert self._scope is None
    self._scope = True
    self._sock = sock
    self._connection = WSConnection(ConnectionType.SERVER)
    try:
        event = await self._next_event()
        if not isinstance(event, Request):
            raise ConnectionError("Failed to establish a connection", event)
        msg = None
        if filter is not None:
            msg = await filter(event)
            if not msg:
                msg = RejectConnection()
            elif msg is True:
                msg = None
            elif isinstance(msg, str):
                msg = AcceptConnection(subprotocol=msg)
        if not msg:
            # Default accept. Only pick a subprotocol if the client offered
            # one; indexing an empty list here would raise IndexError.
            offered = event.subprotocols
            msg = AcceptConnection(subprotocol=offered[0] if offered else None)
        data = self._connection.send(msg)
        await self._sock.send_all(data)
        if not isinstance(msg, AcceptConnection):
            raise ConnectionError("Not accepted", msg)
        # The docstring has always promised the Request back to the caller.
        return event
    finally:
        self._scope = None
def __init__(self, ctx, handshake_flow):
    """Set up the WebSocket layer by replaying a recorded handshake.

    Two wsproto connections are registered in ``self.connections``: one in
    the SERVER role keyed by ``self.client_conn`` and one in the CLIENT role
    keyed by ``self.server_conn``. A ``Request`` built from the flow's host
    and path is passed from the CLIENT-role connection to the SERVER-role
    one, and an ``AcceptConnection`` is passed back. permessage-deflate is
    enabled on both sides when the handshake response's
    ``Sec-WebSocket-Extensions`` header mentions it.

    :param ctx: layer context, forwarded to the base class.
    :param handshake_flow: flow whose ``request`` (host/path) and
        ``response`` (headers) describe the upgrade handshake.
    """
    super().__init__(ctx)
    self.handshake_flow = handshake_flow
    self.flow: WebSocketFlow = None
    self.client_frame_buffer = []
    self.server_frame_buffer = []
    self.connections: dict[object, WSConnection] = {}

    response_headers = handshake_flow.response.headers
    use_deflate = (
        'Sec-WebSocket-Extensions' in response_headers
        and PerMessageDeflate.name in response_headers['Sec-WebSocket-Extensions']
    )
    client_extensions = [PerMessageDeflate()] if use_deflate else []
    server_extensions = [PerMessageDeflate()] if use_deflate else []

    conns = self.connections
    conns[self.client_conn] = WSConnection(ConnectionType.SERVER)
    conns[self.server_conn] = WSConnection(ConnectionType.CLIENT)

    # Both extension objects must learn the parameters the server agreed to.
    for extension_list in (client_extensions, server_extensions):
        if extension_list:
            extension_list[0].finalize(response_headers['Sec-WebSocket-Extensions'])

    handshake_request = Request(
        extensions=client_extensions,
        host=handshake_flow.request.host,
        target=handshake_flow.request.path,
    )
    wire = conns[self.server_conn].send(handshake_request)
    conns[self.client_conn].receive_data(wire)
    # Kept outside the assert so the event is consumed even under -O.
    first_event = next(conns[self.client_conn].events())
    assert isinstance(first_event, events.Request)

    wire = conns[self.client_conn].send(AcceptConnection(extensions=server_extensions))
    conns[self.server_conn].receive_data(wire)
    assert isinstance(next(conns[self.server_conn].events()), events.AcceptConnection)
def test_handshake_extra_accept_headers() -> None:
    """An extra 101-response header is reported (lower-cased) on AcceptConnection."""
    response_headers = [
        (b"connection", b"Upgrade"),
        (b"upgrade", b"WebSocket"),
        (b"X-Foo", b"bar"),
    ]
    received = _make_handshake(101, response_headers)
    assert received == [AcceptConnection(extra_headers=[(b"x-foo", b"bar")])]
def _handle_events(self):
    """Drain pending wsproto events and react to each one.

    Handshake requests are accepted and pings are answered. Text and bytes
    payloads are appended to ``self.input_buffer`` and ``self.event`` is
    set. On a close frame, a close response is sent (only when
    ``self.is_server``), ``self.event`` is set, and this method reports
    that the connection should stop. Any generated protocol bytes are
    written to ``self.sock`` in one call.

    Returns:
        True to keep the connection running; False once a close was
        received or if handling failed with an exception.
    """
    keep_going = True
    out_data = b''
    try:
        for event in self.ws.events():
            if isinstance(event, Request):
                out_data += self.ws.send(AcceptConnection())
            elif isinstance(event, CloseConnection):
                if self.is_server:
                    out_data += self.ws.send(event.response())
                self.event.set()
                keep_going = False
            elif isinstance(event, Ping):
                out_data += self.ws.send(event.response())
            elif isinstance(event, (TextMessage, BytesMessage)):
                self.input_buffer.append(event.data)
                self.event.set()
        if out_data:
            self.sock.send(out_data)
        return keep_going
    except Exception:
        # Best-effort: any protocol or socket failure ends the connection.
        # A bare ``except:`` also swallowed KeyboardInterrupt/SystemExit.
        return False
async def asgi_send(self, message: dict) -> None:
    """Called by the ASGI instance to send a message.

    Dispatches on ``message["type"]`` and ``self.state``:

    * ``websocket.accept`` during the handshake accepts the upgrade,
      offering permessage-deflate and the optional ``subprotocol``.
    * ``websocket.http.response.start`` / ``websocket.http.response.body``
      build an HTTP rejection response.
    * ``websocket.send`` sends a bytes or text message once connected.
    * ``websocket.close`` answers 403 during the handshake; otherwise it
      closes with ``code`` (default 1000) and the optional ``reason``.

    Raises:
        TypeError: if a ``websocket.send`` message has no ``bytes`` and its
            ``text`` is not a ``str``.
        UnexpectedMessage: if the message type is not valid in the current
            state.
    """
    if message["type"] == "websocket.accept" and self.state == ASGIWebsocketState.HANDSHAKE:
        await self.asend(
            AcceptConnection(
                extensions=[PerMessageDeflate()],
                # Optional in the ASGI spec; None when the app selects none.
                subprotocol=message.get("subprotocol"),
            )
        )
        self.state = ASGIWebsocketState.CONNECTED
        self.config.access_logger.access(
            self.scope, {"status": 101, "headers": []}, time() - self.start_time
        )
    elif (
        message["type"] == "websocket.http.response.start"
        and self.state == ASGIWebsocketState.HANDSHAKE
    ):
        self.response = message
        self.config.access_logger.access(self.scope, self.response, time() - self.start_time)
    elif message["type"] == "websocket.http.response.body" and self.state in {
        ASGIWebsocketState.HANDSHAKE,
        ASGIWebsocketState.RESPONSE,
    }:
        await self._asgi_send_rejection(message)
    elif message["type"] == "websocket.send" and self.state == ASGIWebsocketState.CONNECTED:
        if message.get("bytes") is not None:
            await self.asend(BytesMessage(data=bytes(message["bytes"])))
        else:
            # .get(): a message carrying neither key is reported as a
            # TypeError rather than an opaque KeyError.
            text = message.get("text")
            if not isinstance(text, str):
                raise TypeError(f"{text} should be a str")
            await self.asend(TextMessage(data=text))
    elif message["type"] == "websocket.close" and self.state == ASGIWebsocketState.HANDSHAKE:
        await self.send_http_error(403)
        self.state = ASGIWebsocketState.HTTPCLOSED
    elif message["type"] == "websocket.close":
        # "code" is optional in ASGI and defaults to 1000 (normal closure);
        # "reason" is optional as well.
        code = message.get("code")
        if code is None:
            code = 1000
        await self.asend(CloseConnection(code=int(code), reason=message.get("reason")))
        self.state = ASGIWebsocketState.CLOSED
    else:
        raise UnexpectedMessage(self.state, message["type"])
def _make_handshake(
    request_headers: Headers,
    accept_headers: Optional[Headers] = None,
    subprotocol: Optional[str] = None,
    extensions: Optional[List[Extension]] = None,
) -> Tuple[h11.InformationalResponse, bytes]:
    """Drive a server-side WSConnection with an h11 client and return its reply.

    The h11 client sends a ``GET /`` upgrade request carrying a fresh nonce
    plus *request_headers*. The server answers with an AcceptConnection
    built from *accept_headers*, *subprotocol* and *extensions*, and the
    client parses that answer.

    :returns: ``(event parsed by the client, nonce sent in the request)``.
    """
    client = h11.Connection(h11.CLIENT)
    server = WSConnection(SERVER)
    nonce = generate_nonce()

    upgrade_headers = [
        (b"Host", b"localhost"),
        (b"Connection", b"Keep-Alive, Upgrade"),
        (b"Upgrade", b"WebSocket"),
        (b"Sec-WebSocket-Version", b"13"),
        (b"Sec-WebSocket-Key", nonce),
    ] + request_headers
    request_bytes = client.send(
        h11.Request(method="GET", target="/", headers=upgrade_headers)
    )
    server.receive_data(request_bytes)

    accept = AcceptConnection(
        extra_headers=accept_headers or [],
        subprotocol=subprotocol,
        extensions=extensions or [],
    )
    client.receive_data(server.send(accept))
    return client.next_event(), nonce
def websocket(request):
    """Serve an echo WebSocket on the raw client socket behind *request*.

    Text messages are echoed back. The text ``"quit"`` makes the server
    close with ``CloseReason.NORMAL_CLOSURE`` / ``"bye"``. Pings are
    answered and client close frames are acknowledged.

    :param request: the upgrade request; its WSGI environ must expose the
        client socket as ``werkzeug.socket`` or ``gunicorn.socket``.
    :raises InternalServerError: if neither socket key is present.
    :returns: an empty 204 response once the WebSocket session ends.
    """
    # The underlying socket must be provided by the server. Gunicorn and
    # Werkzeug's dev server are known to support this.
    stream = request.environ.get("werkzeug.socket")
    if stream is None:
        stream = request.environ.get("gunicorn.socket")
    if stream is None:
        raise InternalServerError()

    # Initialize the wsproto connection. Need to recreate the request
    # data that was read by the WSGI server already.
    ws = WSConnection(ConnectionType.SERVER)
    in_data = b"GET %s HTTP/1.1\r\n" % request.path.encode("utf8")
    in_data += b"".join(
        f"{header}: {value}\r\n".encode()
        for header, value in request.headers.items()
    )
    in_data += b"\r\n"
    ws.receive_data(in_data)

    running = True
    while True:
        out_data = b""
        for event in ws.events():
            if isinstance(event, WSRequest):
                out_data += ws.send(AcceptConnection())
            elif isinstance(event, CloseConnection):
                out_data += ws.send(event.response())
                running = False
            elif isinstance(event, Ping):
                out_data += ws.send(event.response())
            elif isinstance(event, TextMessage):
                # echo the incoming message back to the client
                if event.data == "quit":
                    out_data += ws.send(
                        CloseConnection(CloseReason.NORMAL_CLOSURE, "bye")
                    )
                    running = False
                else:
                    out_data += ws.send(Message(data=event.data))
        if out_data:
            # sendall: send() may write only part of the buffer.
            stream.sendall(out_data)
        if not running:
            break
        in_data = stream.recv(4096)
        if not in_data:
            # An empty read means the client closed the socket without a
            # closing handshake; stop instead of looping forever on b"".
            break
        ws.receive_data(in_data)

    # The connection will be closed at this point, but WSGI still
    # requires a response.
    return Response("", status=204)
def test_handshake_with_subprotocol():
    """When the response selects an offered subprotocol, AcceptConnection reports it."""
    response_headers = [
        (b"connection", b"Upgrade"),
        (b"upgrade", b"WebSocket"),
        (b"sec-websocket-protocol", b"one"),
    ]
    received = _make_handshake(101, response_headers, subprotocols=["one", "two"])
    assert received == [AcceptConnection(subprotocol="one")]
def accept_request(self, extra_headers: Headers = None, sub_protocol: str = None) -> None:
    """Accept the pending WebSocket handshake and send the response.

    :param extra_headers: extra response headers. They are checked with
        ``self._check_ws_headers`` first; a falsy value sends none.
    :param sub_protocol: subprotocol to announce, or ``None`` for none.
    :raises TypeError: if ``sub_protocol`` is given but is not a ``str``.
    """
    self._check_ws_headers(extra_headers)
    if not (sub_protocol is None or isinstance(sub_protocol, str)):
        raise TypeError('sub_protocol must be a string')
    response = AcceptConnection(extra_headers=extra_headers or [], subprotocol=sub_protocol)
    self._client.sendall(self._ws.send(response))
def test_handshake_with_extension():
    """An extension named in the response is reported on AcceptConnection."""
    extension = FakeExtension(offer_response=True)
    response_headers = [
        (b"connection", b"Upgrade"),
        (b"upgrade", b"WebSocket"),
        (b"sec-websocket-extensions", b"fake"),
    ]
    received = _make_handshake(101, response_headers, extensions=[extension])
    assert received == [AcceptConnection(extensions=[extension])]
def test_successful_handshake():
    """A client/server H11Handshake pair exchanges Request/Accept and both end OPEN."""
    client = H11Handshake(CLIENT)
    server = H11Handshake(SERVER)

    request_bytes = client.send(Request(host="localhost", target="/"))
    server.receive_data(request_bytes)
    assert isinstance(next(server.events()), Request)

    accept_bytes = server.send(AcceptConnection())
    client.receive_data(accept_bytes)
    assert isinstance(next(client.events()), AcceptConnection)

    for side in (client, server):
        assert side.state is ConnectionState.OPEN
def handle_connection(stream: socket.socket) -> None:
    """
    Handle a connection.

    The server operates a request/response cycle, so it performs a synchronous
    loop:

    1) Read data from network into wsproto
    2) Get new events and handle them
    3) Send data from wsproto to network

    The loop ends when the client sends a close frame or closes the socket.

    :param stream: a socket stream
    """
    ws = WSConnection(ConnectionType.SERVER)
    running = True
    while running:
        # 1) Read data from network
        in_data = stream.recv(RECEIVE_BYTES)
        print("Received {} bytes".format(len(in_data)))
        if not in_data:
            # recv() returns b"" once the peer has closed the socket; feeding
            # that to wsproto forever would spin this loop indefinitely.
            print("Client connection dropped unexpectedly")
            return
        ws.receive_data(in_data)

        # 2) Get new events and handle them
        out_data = b""
        for event in ws.events():
            if isinstance(event, Request):
                # Negotiate new WebSocket connection
                print("Accepting WebSocket upgrade")
                out_data += ws.send(AcceptConnection())
            elif isinstance(event, CloseConnection):
                # Print log message and break out
                print("Connection closed: code={} reason={}".format(
                    event.code, event.reason))
                out_data += ws.send(event.response())
                running = False
            elif isinstance(event, TextMessage):
                # Reverse text and send it back to wsproto
                print("Received request and sending response")
                out_data += ws.send(Message(data=event.data[::-1]))
            elif isinstance(event, Ping):
                # Answer the ping: event.response() builds the matching pong,
                # which ws.send() serializes into the outgoing data.
                print("Received ping and sending pong")
                out_data += ws.send(event.response())
            else:
                print(f"Unknown event: {event!r}")

        # 3) Send data from wsproto to network
        print("Sending {} bytes".format(len(out_data)))
        # sendall: send() may write only part of the buffer.
        stream.sendall(out_data)
async def init_for_server(cls, stream):
    """Build a server-side transport over *stream* and complete the WS handshake.

    Waits up to ``WEBSOCKET_HANDSHAKE_TIMEOUT`` for the client's upgrade
    request and accepts it.

    :returns: the new transport.
    :raises TransportError: if the first event is not a ``Request``,
        including when the wait times out.
    """
    ws = WSConnection(ConnectionType.SERVER)
    transport = cls(stream, ws)

    # Placeholder reported if the timeout fires before any event arrives.
    event = "Websocket handshake timeout"
    with trio.move_on_after(WEBSOCKET_HANDSHAKE_TIMEOUT):
        event = await transport._next_ws_event()

    if not isinstance(event, Request):
        transport.logger.warning("Unexpected event during WebSocket handshake", ws_event=event)
        raise TransportError(f"Unexpected event during WebSocket handshake: {event}")

    transport.logger.debug("Accepting WebSocket upgrade")
    await transport._net_send(AcceptConnection())
    return transport
async def upgrade(self, request):
    """Complete the WebSocket upgrade handshake for *request*.

    The request line and headers are re-serialized into raw HTTP and fed to
    ``self.protocol``. If that produces a ``Request`` event, the protocol's
    ``AcceptConnection`` response is written to ``self.socket``.

    Raises:
        HTTPError: ``HTTPStatus.BAD_REQUEST`` if the protocol rejects the
            data or does not yield a ``Request`` event.
    """
    data = '{} {} HTTP/1.1\r\n'.format(request.method, request.url)
    data += '\r\n'.join('{}: {}'.format(k, v) for k, v in request.headers.items()) + '\r\n\r\n'
    data = data.encode()
    try:
        self.protocol.receive_data(data)
    except RemoteProtocolError as exc:
        raise HTTPError(HTTPStatus.BAD_REQUEST) from exc
    # Default of None: with no pending events a bare next() raises
    # StopIteration, which a coroutine turns into RuntimeError.
    event = next(self.protocol.events(), None)
    if not isinstance(event, Request):
        raise HTTPError(HTTPStatus.BAD_REQUEST)
    data = self.protocol.send(AcceptConnection())
    await self.socket.sendall(data)
async def init_for_server(  # type: ignore[misc]
    cls: Type["Transport"], stream: Stream, upgrade_request: Optional[H11Request] = None
) -> "Transport":
    """Build a server-side transport over *stream* and complete the WS handshake.

    When *upgrade_request* is given (an HTTP request already read from the
    stream), its headers and target seed the wsproto connection via
    ``initiate_upgrade_connection``. The method then waits up to
    ``WEBSOCKET_HANDSHAKE_TIMEOUT`` for the ``Request`` event and accepts it.

    :returns: the new transport.
    :raises TransportError: if the first event is not a ``Request``,
        including when the wait times out.
    """
    ws = WSConnection(ConnectionType.SERVER)
    if upgrade_request:
        ws.initiate_upgrade_connection(
            headers=upgrade_request.headers, path=upgrade_request.target
        )
    transport = cls(stream, ws)

    # Placeholder reported if the timeout fires before any event arrives.
    event: Union[str, Event] = "Websocket handshake timeout"
    with trio.move_on_after(WEBSOCKET_HANDSHAKE_TIMEOUT):
        event = await transport._next_ws_event()

    if not isinstance(event, Request):
        transport.logger.warning("Unexpected event during WebSocket handshake", ws_event=event)
        raise TransportError(f"Unexpected event during WebSocket handshake: {event}")

    transport.logger.debug("Accepting WebSocket upgrade")
    await transport._net_send(AcceptConnection())
    return transport
def new_conn(sock: socket.socket) -> None:
    """Serve one WebSocket echo session on an accepted client socket.

    The session accepts the handshake, offering permessage-deflate. Each
    message fragment is echoed back with its ``message_finished`` flag,
    pings are answered, and close frames are acknowledged if the
    connection is not already CLOSED. The loop ends on a close frame, an
    empty read, or a socket error.

    :param sock: the accepted client socket; it is closed by this function
        on every exit path.
    """
    global count
    print("test_server.py received connection {}".format(count))
    count += 1
    ws = WSConnection(SERVER)
    closed = False
    try:
        while not closed:
            try:
                data: Optional[bytes] = sock.recv(65535)
            except OSError:  # socket.error is an alias of OSError
                data = None

            # An empty read or a socket error is passed to wsproto as None.
            ws.receive_data(data or None)

            outgoing_data = b""
            for event in ws.events():
                if isinstance(event, Request):
                    outgoing_data += ws.send(
                        AcceptConnection(extensions=[PerMessageDeflate()]))
                elif isinstance(event, Message):
                    outgoing_data += ws.send(
                        Message(data=event.data, message_finished=event.message_finished))
                elif isinstance(event, Ping):
                    outgoing_data += ws.send(event.response())
                elif isinstance(event, CloseConnection):
                    closed = True
                    if ws.state is not ConnectionState.CLOSED:
                        outgoing_data += ws.send(event.response())

            if not data:
                closed = True

            try:
                sock.sendall(outgoing_data)
            except OSError:
                closed = True
    finally:
        # Close even if anything in the loop raises; previously the socket
        # leaked in that case.
        sock.close()
async def init_for_server(cls, stream, first_request_data=None):
    """Build a server-side transport over *stream* and complete the WS handshake.

    *first_request_data*, if non-empty, is raw handshake bytes already read
    from the stream and is fed to wsproto first. The method then waits up
    to ``WEBSOCKET_HANDSHAKE_TIMEOUT`` for the ``Request`` event and
    accepts it.

    :returns: the new transport.
    :raises TransportError: if wsproto rejects the handshake data
        (``RemoteProtocolError``), or if the first event is not a
        ``Request``, including when the wait times out.
    """
    ws = WSConnection(ConnectionType.SERVER)
    try:
        if first_request_data:
            ws.receive_data(first_request_data)
        transport = cls(stream, ws)
        # Placeholder reported if the timeout fires before any event arrives.
        event = "Websocket handshake timeout"
        with trio.move_on_after(WEBSOCKET_HANDSHAKE_TIMEOUT):
            event = await transport._next_ws_event()
    except RemoteProtocolError as err:
        raise TransportError(f"Invalid WebSocket query: {err}") from err

    if not isinstance(event, Request):
        transport.logger.warning("Unexpected event during WebSocket handshake", ws_event=event)
        raise TransportError(
            f"Unexpected event during WebSocket handshake: {event}")

    transport.logger.debug("Accepting WebSocket upgrade")
    await transport._net_send(AcceptConnection())
    return transport
def handle_connection(stream):
    '''
    Handle a connection.

    The server operates a request/response cycle, so it performs a synchronous
    loop:

    1) Read data from network into wsproto
    2) Get every pending wsproto event
    3) Handle each event
    4) Send data from wsproto to network

    The loop ends when the client sends a close frame or closes the socket.

    :param stream: a socket stream
    '''
    ws = WSConnection(ConnectionType.SERVER)
    running = True
    while running:
        # 1) Read data from network
        in_data = stream.recv(RECEIVE_BYTES)
        print('Received {} bytes'.format(len(in_data)))
        if not in_data:
            # recv() returns b"" once the peer has closed the socket.
            print('Client connection dropped unexpectedly')
            return
        ws.receive_data(in_data)

        # 2) + 3) Fetch a fresh events() generator every pass and drain it.
        # One read may carry zero or several events. A single long-lived
        # generator is exhausted permanently the first time it runs dry.
        out_data = b''
        for event in ws.events():
            if isinstance(event, Request):
                # Negotiate new WebSocket connection
                print('Accepting WebSocket upgrade')
                out_data += ws.send(AcceptConnection())
            elif isinstance(event, CloseConnection):
                # Print log message and break out. The code may be an enum
                # member or a plain int, so don't rely on .value/.name.
                code = event.code
                print('Connection closed: code={}/{} reason={}'.format(
                    getattr(code, 'value', code), getattr(code, 'name', 'UNKNOWN'),
                    event.reason))
                out_data += ws.send(event.response())
                running = False
            elif isinstance(event, TextMessage):
                # Reverse text and send it back to wsproto
                print('Received request and sending response')
                out_data += ws.send(Message(data=event.data[::-1]))
            elif isinstance(event, Ping):
                # Answer the ping: event.response() builds the matching pong,
                # which ws.send() serializes into the outgoing data.
                print('Received ping and sending pong')
                out_data += ws.send(event.response())
            else:
                # Previously out_data was left unbound or stale here, which
                # caused a NameError or resent the previous reply.
                print('Unknown event: {!r}'.format(event))

        # 4) Send data from wsproto to network
        print('Sending {} bytes'.format(len(out_data)))
        stream.send(out_data)
async def sender(send: Send, port: QueuePort):
    """Translate ASGI websocket events pulled from *port* into HTTP messages on *send*.

    Runs while the enclosing ``websocket`` connection is neither CLOSED nor
    LOCAL_CLOSING, or until *port* yields ``QueuePort.PORT_CLOSED_SENTINEL``.
    A local ``state`` tracks the handshake: ``connecting``, then either
    ``connected`` or ``denied``.

    * ``websocket.accept``: the wsproto AcceptConnection bytes are parsed
      back through ``connection`` and sent as ``http.response.start``.
    * ``websocket.send`` / ``websocket.close`` (once connected): the frames
      are encoded by ``websocket`` and sent as ``http.response.body``.
    * ``websocket.close`` before acceptance: the request is denied with 403.
    * ``websocket.http.response.*``: forwarded without the ``websocket.``
      prefix as a custom denial response.

    :raises ValueError: if a denial response is sent after acceptance or
        out of order.
    """
    state = 'connecting'
    while websocket.state not in (ConnectionState.CLOSED, ConnectionState.LOCAL_CLOSING):
        event = await port.pull()
        if event is QueuePort.PORT_CLOSED_SENTINEL:
            return
        assert isinstance(event, dict)
        if event["type"] == "websocket.send":
            # ASGI allows both "bytes" and "text" keys to be present with
            # exactly one of them non-None, so test values, not key presence.
            if event.get("bytes") is not None:
                await send({
                    'type': 'http.response.body',
                    'body': websocket.send(BytesMessage(data=event["bytes"])),
                    'more_body': True
                })
            elif event.get("text") is not None:
                await send({
                    'type': 'http.response.body',
                    'body': websocket.send(TextMessage(data=event["text"])),
                    'more_body': True
                })
        elif event["type"] == "websocket.close":
            if state == "connected":
                # "code" is optional in ASGI (default 1000); "reason" is
                # optional too.
                code = event.get("code")
                if code is None:
                    code = 1000
                await send({
                    'type': 'http.response.body',
                    'body': websocket.send(
                        CloseConnection(code=code, reason=event.get("reason"))),
                    'more_body': False
                })
                await port.close({
                    'type': 'websocket.close',
                    'code': code
                })
            else:
                state = "denied"
                await send({
                    'type': 'http.response.start',
                    'status': 403,
                    'headers': []
                })
                await send({
                    'type': 'http.response.body',
                    'body': b'',
                    'more_body': False
                })
                await port.close({'type': 'websocket.close'})
        elif event["type"] in ("websocket.http.response.start", "websocket.http.response.body"):
            if state == "connected":
                raise ValueError(
                    "You already accepted a websocket connection.")
            if state == "connecting" and event[
                    "type"] == "websocket.http.response.body":
                raise ValueError("You did not start a response.")
            elif state == "denied" and event[
                    "type"] == "websocket.http.response.start":
                raise ValueError("You already started a response.")
            state = "denied"
            event = event.copy()
            event["type"] = event["type"][len("websocket."):]
            await send(event)
        elif event["type"] == "websocket.accept":
            raw = websocket.send(
                AcceptConnection(
                    extra_headers=event.get("headers", []),
                    subprotocol=event.get("subprotocol", None)))
            connection.receive_data(raw)
            response = connection.next_event()
            state = "connected"
            await send({
                'type': 'http.response.start',
                'status': response.status_code,
                'headers': response.headers
            })
def test_handshake():
    """A plain 101 upgrade response yields exactly one bare AcceptConnection."""
    upgrade_headers = [(b"connection", b"Upgrade"), (b"upgrade", b"WebSocket")]
    assert _make_handshake(101, upgrade_headers) == [AcceptConnection()]