def handle_todo_post_save(sender, instance, created, **kwargs):
    """Notify a local websocket endpoint about a saved instance.

    Does nothing unless ``sender`` has an ``APP_PORT`` attribute. Otherwise it
    does the following, then deletes ``sender.APP_PORT`` so later calls are
    no-ops:

    1. Opens a TCP connection to ``localhost:<APP_PORT>`` and performs a
       client websocket handshake for target ``ws/todos``.
    2. Sends ``str(instance.pk)`` as a message.
    3. Closes the websocket with code 1000.

    :param sender: object carrying the optional ``APP_PORT`` attribute
        (presumably the model class of a ``post_save`` signal — TODO confirm).
    :param instance: the saved object; only ``instance.pk`` is used.
    :param created: unused.
    """
    if not hasattr(sender, 'APP_PORT'):
        return
    # The context manager closes the socket on every path, including when the
    # handshake or any send/receive step raises.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as conn:
        conn.connect(('localhost', int(sender.APP_PORT)))
        ws = WSConnection(ConnectionType.CLIENT)
        # NOTE(review): HTTP request targets normally start with "/"; confirm
        # the server routes "ws/todos" as intended.
        net_send(
            ws.send(Request(host=f"localhost:{sender.APP_PORT}", target="ws/todos")),
            conn,
        )
        net_recv(ws, conn)
        handle_events(ws)

        net_send(ws.send(Message(data=str(instance.pk))), conn)
        net_recv(ws, conn)
        handle_events(ws)

        net_send(ws.send(CloseConnection(code=1000)), conn)
        net_recv(ws, conn)
        # Half-close our side, then read once more so wsproto sees the
        # peer's final data / EOF.
        conn.shutdown(socket.SHUT_WR)
        net_recv(ws, conn)
    del sender.APP_PORT
def __init__(
    self,
    app: Type[ASGIFramework],
    loop: asyncio.AbstractEventLoop,
    config: Config,
    transport: asyncio.BaseTransport,
    *,
    upgrade_request: Optional[h11.Request] = None,
) -> None:
    """Initialise a wsproto-backed websocket server protocol.

    :param app: the ASGI application class to serve.
    :param loop: event loop passed to the base protocol.
    :param config: server configuration; its ``websocket_max_message_size``
        bounds the message buffer.
    :param transport: the asyncio transport for this connection.
    :param upgrade_request: if given, an h11 request that asked to upgrade.
        Its headers and target are handed to wsproto, and any resulting
        events are handled immediately.
    """
    super().__init__(loop, config, transport, "wsproto")
    self.stop_keep_alive_timeout()

    self.app = app
    self.connection = WSConnection(ConnectionType.SERVER)
    self.app_queue: asyncio.Queue = asyncio.Queue()
    self.response: Optional[dict] = None
    self.scope: Optional[dict] = None
    self.state = ASGIWebsocketState.HANDSHAKE
    self.task: Optional[asyncio.Future] = None
    self.buffer = WebsocketBuffer(self.config.websocket_max_message_size)

    if upgrade_request is None:
        return
    self.connection.initiate_upgrade_connection(
        upgrade_request.headers, upgrade_request.target
    )
    self.handle_events()
def test_upgrade_request() -> None:
    """An upgrade handed to wsproto directly surfaces as a Request event.

    Checks the parsed Request fields. Also checks that ``host`` and the
    extension/protocol headers are absent from ``extra_headers`` while the
    remaining headers are kept verbatim.
    """
    server = WSConnection(SERVER)
    upgrade_headers = [
        (b"Host", b"localhost"),
        (b"Connection", b"Keep-Alive, Upgrade"),
        (b"Upgrade", b"WebSocket"),
        (b"Sec-WebSocket-Version", b"13"),
        (b"Sec-WebSocket-Key", generate_nonce()),
        (b"X-Foo", b"bar"),
    ]
    server.initiate_upgrade_connection(upgrade_headers, "/")
    request = cast(Request, next(server.events()))

    assert request.extensions == []
    assert request.host == "localhost"
    assert request.subprotocols == []
    assert request.target == "/"

    extra = normed_header_dict(request.extra_headers)
    for absent in (b"host", b"sec-websocket-extensions", b"sec-websocket-protocol"):
        assert absent not in extra
    expected = {
        b"connection": b"Keep-Alive, Upgrade",
        b"sec-websocket-version": b"13",
        b"upgrade": b"WebSocket",
        b"x-foo": b"bar",
    }
    for name, value in expected.items():
        assert extra[name] == value
def _make_handshake(
    request_headers: Headers,
    accept_headers: Optional[Headers] = None,
    subprotocol: Optional[str] = None,
    extensions: Optional[List[Extension]] = None,
) -> Tuple[h11.InformationalResponse, bytes]:
    """Drive a server-side wsproto handshake from an h11 client.

    :param request_headers: extra headers appended after the standard
        upgrade headers.
    :param accept_headers: ``extra_headers`` for the AcceptConnection.
    :param subprotocol: subprotocol the server accepts with.
    :param extensions: extensions the server accepts with.
    :returns: the client's first h11 event after the server's reply, and the
        ``Sec-WebSocket-Key`` nonce that was sent.
    """
    client = h11.Connection(h11.CLIENT)
    server = WSConnection(SERVER)
    nonce = generate_nonce()

    base_headers = [
        (b"Host", b"localhost"),
        (b"Connection", b"Keep-Alive, Upgrade"),
        (b"Upgrade", b"WebSocket"),
        (b"Sec-WebSocket-Version", b"13"),
        (b"Sec-WebSocket-Key", nonce),
    ]
    upgrade = h11.Request(method="GET", target="/", headers=base_headers + request_headers)
    server.receive_data(client.send(upgrade))

    accept = AcceptConnection(
        extra_headers=accept_headers or [],
        subprotocol=subprotocol,
        extensions=extensions or [],
    )
    client.receive_data(server.send(accept))
    return client.next_event(), nonce
async def start_client(self, sock: anyio.abc.SocketStream, addr, path: str,
                       headers: Optional[List] = None,
                       subprotocols: Optional[List[str]] = None):
    """Start a client WS connection on this socket.

    Sends the opening handshake, using ``addr[0]`` as the host and ``path``
    as the target, then waits for the server's first event.

    Returns: the AcceptConnection message.
    Raises: ConnectionError if the first event is not an AcceptConnection.
    """
    # Detect a concurrent start_* call before replacing the socket/connection
    # or writing anything to the wire (same ordering as start_server).
    assert self._scope is None
    self._scope = True
    try:
        self._sock = sock
        self._connection = WSConnection(ConnectionType.CLIENT)
        if headers is None:
            headers = []
        if subprotocols is None:
            subprotocols = []
        data = self._connection.send(
            Request(
                host=addr[0],
                target=path,
                extra_headers=headers,
                subprotocols=subprotocols))
        await self._sock.send_all(data)

        event = await self._next_event()
        if not isinstance(event, AcceptConnection):
            raise ConnectionError("Failed to establish a connection", event)
        return event
    finally:
        self._scope = None
async def start_server(self, sock: anyio.abc.SocketStream, filter=None):  # pylint: disable=W0622
    """Start a server WS connection on this socket.

    Filter: an async callable that gets passed the initial Request.
    It may return an AcceptConnection message, a bool, or a string (the
    subprotocol to use).
    - A falsy result rejects the connection.
    - ``True``, or no filter at all, accepts with the first subprotocol the
      client offered, or with no subprotocol if it offered none.

    Returns: the Request message.
    Raises: ConnectionError if the first event is not a Request or the
    connection was not accepted.
    """
    assert self._scope is None
    self._scope = True
    self._sock = sock
    self._connection = WSConnection(ConnectionType.SERVER)
    try:
        event = await self._next_event()
        if not isinstance(event, Request):
            raise ConnectionError("Failed to establish a connection", event)

        msg = None
        if filter is not None:
            msg = await filter(event)
            if not msg:
                msg = RejectConnection()
            elif msg is True:
                msg = None
            elif isinstance(msg, str):
                msg = AcceptConnection(subprotocol=msg)
        if not msg:
            # Default accept. Guard the index: a client offering no
            # subprotocols previously caused an IndexError here.
            subprotocol = event.subprotocols[0] if event.subprotocols else None
            msg = AcceptConnection(subprotocol=subprotocol)

        data = self._connection.send(msg)
        await self._sock.send_all(data)
        if not isinstance(msg, AcceptConnection):
            raise ConnectionError("Not accepted", msg)
        # Return the Request, as documented (previously returned None).
        return event
    finally:
        self._scope = None
class Websocket(WebsocketPrototype):
    """Server-side websocket running a handler parallel to the I/O.
    """

    def __init__(self, socket):
        super().__init__()
        self.socket = socket
        self.protocol = WSConnection(ConnectionType.SERVER)

    async def upgrade(self, request):
        """Complete the server side of the websocket handshake for ``request``.

        Rebuilds the HTTP request line and headers from the already-parsed
        ``request`` and feeds the bytes to wsproto. If wsproto yields a
        Request event, replies with an AcceptConnection.

        Raises:
            HTTPError: BAD_REQUEST if wsproto rejects the bytes or does not
                produce a Request event.
        """
        data = '{} {} HTTP/1.1\r\n'.format(request.method, request.url)
        data += '\r\n'.join(('{}: {}'.format(k, v) for k, v in request.headers.items())) + '\r\n\r\n'
        data = data.encode()
        try:
            self.protocol.receive_data(data)
        except RemoteProtocolError:
            raise HTTPError(HTTPStatus.BAD_REQUEST)
        # Default to None: a bare next() on an empty event stream raises
        # StopIteration, which a coroutine turns into RuntimeError (PEP 479).
        event = next(self.protocol.events(), None)
        if not isinstance(event, Request):
            raise HTTPError(HTTPStatus.BAD_REQUEST)
        data = self.protocol.send(AcceptConnection())
        await self.socket.sendall(data)
def __init__(self, ctx, handshake_flow):
    """Build paired wsproto connections for an already-completed handshake.

    A SERVER-side WSConnection is created for the client connection and a
    CLIENT-side one for the server connection. The opening handshake is then
    replayed between the two so both state machines end up open. Both get
    PerMessageDeflate if the response's ``Sec-WebSocket-Extensions`` header
    mentions it.
    """
    super().__init__(ctx)
    self.handshake_flow = handshake_flow
    self.flow: WebSocketFlow = None
    self.client_frame_buffer = []
    self.server_frame_buffer = []
    self.connections: dict[object, WSConnection] = {}

    response_headers = handshake_flow.response.headers
    deflate_negotiated = (
        'Sec-WebSocket-Extensions' in response_headers
        and PerMessageDeflate.name in response_headers['Sec-WebSocket-Extensions']
    )
    client_extensions = [PerMessageDeflate()] if deflate_negotiated else []
    server_extensions = [PerMessageDeflate()] if deflate_negotiated else []

    self.connections[self.client_conn] = WSConnection(ConnectionType.SERVER)
    self.connections[self.server_conn] = WSConnection(ConnectionType.CLIENT)

    # Configure each extension from the negotiated header (client's first).
    for extension in client_extensions + server_extensions:
        extension.finalize(response_headers['Sec-WebSocket-Extensions'])

    facing_client = self.connections[self.client_conn]
    facing_server = self.connections[self.server_conn]

    # Replay the handshake: CLIENT-side request -> SERVER-side accept.
    request = Request(extensions=client_extensions,
                      host=handshake_flow.request.host,
                      target=handshake_flow.request.path)
    facing_client.receive_data(facing_server.send(request))
    first_event = next(facing_client.events())
    assert isinstance(first_event, events.Request)

    facing_server.receive_data(facing_client.send(AcceptConnection(extensions=server_extensions)))
    assert isinstance(next(facing_server.events()), events.AcceptConnection)
def _make_connection_request(request_headers, method="GET"):
    # type: (List[Tuple[str, str]], str) -> Request
    """Send an h11 request with ``request_headers`` to a fresh server
    WSConnection and return the first event it produces.

    The type comment previously declared a single argument, which type
    checkers reject because the function takes two.
    """
    client = h11.Connection(h11.CLIENT)
    server = WSConnection(SERVER)
    server.receive_data(
        client.send(h11.Request(method=method, target="/", headers=request_headers))
    )
    return next(server.events())
def _make_connection_request(request_headers: Headers, method: str = "GET") -> Request:
    """Feed an h11 request to a fresh server WSConnection; return its first event.

    :param request_headers: headers for the request to target ``/``.
    :param method: HTTP method of the request.
    """
    upgrade = h11.Request(method=method, target="/", headers=request_headers)
    wire_bytes = h11.Connection(h11.CLIENT).send(upgrade)
    server = WSConnection(SERVER)
    server.receive_data(wire_bytes)
    return next(server.events())  # type: ignore
def __init__(self, host: str, port: int):
    """Store the address to serve on and prepare a server-side wsproto state.

    :param host: host to bind, checked by ``_check_init_arguments``.
    :param port: port to bind, checked by ``_check_init_arguments``.
    """
    # Validate the arguments before any attribute is assigned.
    self._check_init_arguments(host, port)
    self._host, self._port = host, port
    # NOTE(review): these two start as None, so the annotations would more
    # accurately be Optional[...].
    self._server: StreamServer = None
    self._ws = WSConnection(ConnectionType.SERVER)
    self._client: socket = None  # client socket provided by the StreamServer
    self._running = True
def __init__(self, path: str, *, framework: ASGIFramework = echo_framework) -> None:
    """Wire a client WSConnection to a WebsocketServer over an in-memory stream.

    The client's opening handshake for ``path`` is fed straight into the
    server's wsproto connection.
    """
    self.client_stream, server_stream = trio.testing.memory_stream_pair()
    server_stream.socket = MockSocket()
    self.server = WebsocketServer(framework, Config(), server_stream)
    self.connection = WSConnection(ConnectionType.CLIENT)
    opening_handshake = self.connection.send(Request(target=path, host="hypercorn"))
    self.server.connection.receive_data(opening_handshake)
def net_recv(ws: WSConnection, conn: socket.socket) -> None:
    """Read pending data from the network into the websocket.

    Performs one ``conn.recv(RECEIVE_BYTES)``. Received bytes are passed to
    ``ws.receive_data``; an empty read is passed as ``None``.
    """
    chunk = conn.recv(RECEIVE_BYTES)
    if chunk:
        print("Received {} bytes".format(len(chunk)))
        ws.receive_data(chunk)
        return
    # Zero bytes means the TCP socket has been closed; wsproto needs None to
    # update its internal state.
    print("Received 0 bytes (connection closed)")
    ws.receive_data(None)
def _establish_websocket_handshake(self, host: str, path: str, headers: Headers,
                                   extensions: List[str], sub_protocols: List[str]) -> None:
    """Create a client wsproto connection and send its opening handshake.

    ``None`` for ``headers``, ``extensions`` or ``sub_protocols`` is treated
    as an empty list. The handshake bytes are written with ``self._sock.sendall``.
    """
    self._ws = WSConnection(ConnectionType.CLIENT)
    if headers is None:
        headers = []
    if extensions is None:
        extensions = []
    if sub_protocols is None:
        sub_protocols = []
    handshake = self._ws.send(
        Request(
            host=host,
            target=path,
            extra_headers=headers,
            extensions=extensions,
            subprotocols=sub_protocols,
        )
    )
    self._sock.sendall(handshake)
def __init__(
    self,
    path: str,
    event_loop: asyncio.AbstractEventLoop,
    *,
    framework: Type[ASGIFramework] = EchoFramework,
) -> None:
    """Attach a client WSConnection to a WebsocketServer on a MockTransport.

    The client's opening handshake for ``path`` is delivered through the
    server's ``data_received``.
    """
    self.transport = MockTransport()
    self.server = WebsocketServer(  # type: ignore
        framework, event_loop, Config(), self.transport)
    self.connection = WSConnection(ConnectionType.CLIENT)
    opening_handshake = self.connection.send(Request(target=path, host="hypercorn"))
    self.server.data_received(opening_handshake)
def _make_handshake_rejection(status_code, body=None):
    """Answer a client handshake with a plain HTTP response; return client events.

    :param status_code: status of the h11 Response sent back.
    :param body: optional response body. When given, it is sent with a
        matching Content-Length.
    :returns: every event the client WSConnection produced.
    """
    client = WSConnection(CLIENT)
    server = h11.Connection(h11.SERVER)
    server.receive_data(client.send(Request(host="localhost", target="/")))

    has_body = body is not None
    headers = [("Content-Length", str(len(body)))] if has_body else []
    client.receive_data(server.send(h11.Response(status_code=status_code, headers=headers)))
    if has_body:
        client.receive_data(server.send(h11.Data(data=body)))
    client.receive_data(server.send(h11.EndOfMessage()))
    return list(client.events())
async def init_for_client(
    cls: Type["Transport"], stream: Stream, host: str, keepalive: Optional[int] = None
) -> "Transport":
    """Create a client transport and perform the WebSocket opening handshake.

    Sends a Request to ``TRANSPORT_TARGET`` carrying a ``User-Agent`` header,
    then waits for the first websocket event.

    Raises:
        TransportError: if that event is not an AcceptConnection.
    """
    transport = cls(stream, WSConnection(ConnectionType.CLIENT), keepalive)

    # A client WebSocket must start the handshake by sending a Request.
    opening = Request(
        host=host,
        target=TRANSPORT_TARGET,
        extra_headers=[(b"User-Agent", USER_AGENT.encode())],
    )
    await transport._net_send(opening)

    event = await transport._next_ws_event()
    if not isinstance(event, AcceptConnection):
        transport.logger.warning("Unexpected event during WebSocket handshake", ws_event=event)
        raise TransportError(f"Unexpected event during WebSocket handshake: {event}")
    transport.logger.debug("WebSocket negotiation complete", ws_event=event)
    return transport
def __init__(self, stream: Stream, ws: WSConnection, keepalive: Optional[int] = None):
    """Wrap a stream and its wsproto connection.

    Assigns a random hex ``conn_id`` and binds it into this transport's
    logger. Creates one ``ws.events()`` generator that is kept for reuse.
    """
    self.stream, self.ws, self.keepalive = stream, ws, keepalive
    self.conn_id = uuid4().hex
    self.logger = logger.bind(conn_id=self.conn_id)
    self._ws_events = ws.events()
class Websocket(WebsocketPrototype):
    """Client-side websocket over a curio TCP socket."""

    def __init__(self):
        super().__init__()
        self.socket = curio.socket.socket(curio.socket.AF_INET, curio.socket.SOCK_STREAM)
        self.protocol = WSConnection(ConnectionType.CLIENT)

    async def connect(self, path, host, port):
        """Connect to ``host:port`` and perform the client handshake for ``path``.

        Keeps reading until wsproto yields an event, because the HTTP upgrade
        response is not guaranteed to arrive in a single ``recv``.

        Raises:
            Exception: if the peer closes before answering, or the first
                event is not an AcceptConnection.
        """
        await self.socket.connect((host, port))
        request = Request(host=f'{host}:{port}', target=path)
        await self.socket.sendall(self.protocol.send(request))
        event = None
        while event is None:
            upgrade_response = await self.socket.recv(8096)
            if not upgrade_response:
                # Peer closed the connection before completing the handshake.
                break
            self.protocol.receive_data(upgrade_response)
            # Default of None: a bare next() on an empty stream would raise
            # StopIteration, turned into RuntimeError inside a coroutine.
            event = next(self.protocol.events(), None)
        if not isinstance(event, AcceptConnection):
            raise Exception('Websocket handshake failed.')
def handle_events(ws: WSConnection) -> None:
    """Print a line for each pending websocket event.

    Only AcceptConnection and TextMessage are expected.

    Raises:
        Exception: on any other event type.
    """
    for event in ws.events():
        if isinstance(event, TextMessage):
            print("Received message: {}".format(event.data))
            continue
        if isinstance(event, AcceptConnection):
            print("WebSocket negotiation complete")
            continue
        raise Exception("Do not know how to handle event: " + str(event))
def handle_connection(stream: socket.socket) -> None:
    """
    Handle a connection.

    The server operates a request/response cycle, so it performs a synchronous
    loop:

    1) Read data from network into wsproto
    2) Get new events and handle them
    3) Send data from wsproto to network

    The loop ends after answering a CloseConnection, or as soon as the peer
    closes the TCP connection (a zero-byte read).

    :param stream: a socket stream
    """
    ws = WSConnection(ConnectionType.SERVER)
    running = True
    while running:
        # 1) Read data from network
        in_data = stream.recv(RECEIVE_BYTES)
        print("Received {} bytes".format(len(in_data)))
        if not in_data:
            # Peer closed TCP without a close handshake. Without this check
            # the loop would spin forever on empty reads.
            break
        ws.receive_data(in_data)

        # 2) Get new events and handle them
        out_data = b""
        for event in ws.events():
            if isinstance(event, Request):
                # Negotiate new WebSocket connection
                print("Accepting WebSocket upgrade")
                out_data += ws.send(AcceptConnection())
            elif isinstance(event, CloseConnection):
                # Log, queue the matching close reply, and stop looping.
                print("Connection closed: code={} reason={}".format(
                    event.code, event.reason))
                out_data += ws.send(event.response())
                running = False
            elif isinstance(event, TextMessage):
                # Reverse text and send it back to wsproto
                print("Received request and sending response")
                out_data += ws.send(Message(data=event.data[::-1]))
            elif isinstance(event, Ping):
                # Answer the ping with its matching pong frame.
                print("Received ping and sending pong")
                out_data += ws.send(event.response())
            else:
                print(f"Unknown event: {event!r}")

        # 3) Send data from wsproto to network. sendall() keeps writing until
        # the whole buffer is sent; send() may transmit only part of it.
        print("Sending {} bytes".format(len(out_data)))
        stream.sendall(out_data)
def update_reports(server, agent):
    """Open a websocket to ``<server>/updateReports?agent=<agent>`` and close it.

    Once the server accepts the handshake, a NORMAL_CLOSURE close frame is
    sent and the socket is closed. (Presumably this triggers report
    generation on an Autobahn-style test server — TODO confirm.)

    The loop also stops if the server closes the TCP connection first. The
    socket is closed on every path; errors listed in CONNECTION_EXCEPTIONS
    raised by close() are ignored.
    """
    uri = urlparse(server + '/updateReports?agent=%s' % agent)
    connection = WSConnection(CLIENT)
    sock = socket.socket()
    try:
        sock.connect((uri.hostname, uri.port or 80))
        sock.sendall(
            connection.send(
                Request(host=uri.netloc, target='%s?%s' % (uri.path, uri.query))))
        closed = False
        while not closed:
            data = sock.recv(65535)
            if not data:
                # Server closed the connection (e.g. after rejecting the
                # handshake); stop instead of waiting forever.
                break
            connection.receive_data(data)
            for event in connection.events():
                if isinstance(event, AcceptConnection):
                    sock.sendall(
                        connection.send(
                            CloseConnection(code=CloseReason.NORMAL_CLOSURE)))
                    closed = True
                    break
    finally:
        try:
            sock.close()
        except CONNECTION_EXCEPTIONS:
            pass
def get_case_count(server):
    """Fetch the JSON-decoded value the server sends on ``/getCaseCount``.

    Connects as a websocket client and reads text fragments until one
    message is complete. That message is decoded with ``json.loads``, then a
    NORMAL_CLOSURE close frame is sent.

    Returns the decoded value, or ``None`` if the connection ended before a
    complete message arrived. The socket is always closed.
    """
    uri = urlparse(server + '/getCaseCount')
    connection = WSConnection(CLIENT)
    sock = socket.socket()
    case_count = None
    try:
        sock.connect((uri.hostname, uri.port or 80))
        sock.sendall(connection.send(Request(host=uri.netloc, target=uri.path)))
        # Kept outside the loop: one text message may be split across several
        # recv() calls, so its fragments must survive between reads.
        text = ""
        while case_count is None:
            data = sock.recv(65535)
            if not data:
                # Peer closed before sending the count; avoid spinning forever.
                break
            connection.receive_data(data)
            out_data = b""
            for event in connection.events():
                if isinstance(event, TextMessage):
                    text += event.data
                    if event.message_finished:
                        case_count = json.loads(text)
                        text = ""
                        out_data += connection.send(
                            CloseConnection(code=CloseReason.NORMAL_CLOSURE))
            try:
                sock.sendall(out_data)
            except CONNECTION_EXCEPTIONS:
                break
    finally:
        sock.close()
    return case_count
def __init__(self, sock=None, connection_type=None, receive_bytes=4096,
             thread_class=threading.Thread, event_class=threading.Event):
    """Set up the wsproto connection, run the handshake, and start the I/O thread.

    The constructor blocks until the background thread sets ``self.event``,
    then clears the event again.
    """
    self.sock = sock
    self.receive_bytes = receive_bytes
    self.input_buffer = []
    self.event = event_class()
    self.connected = False
    self.is_server = connection_type == ConnectionType.SERVER
    self.ws = WSConnection(connection_type)
    self.handshake()

    worker = thread_class(target=self._thread)
    self.thread = worker
    worker.start()
    self.event.wait()
    self.event.clear()
async def init_for_server(  # type: ignore[misc]
    cls: Type["Transport"], stream: Stream, upgrade_request: Optional[H11Request] = None
) -> "Transport":
    """Create a server transport and accept the client's WebSocket handshake.

    If ``upgrade_request`` is given, its headers and target start the wsproto
    upgrade. Waits up to ``WEBSOCKET_HANDSHAKE_TIMEOUT`` for the client's
    Request event.

    Raises:
        TransportError: on timeout or any event other than a Request.
    """
    ws = WSConnection(ConnectionType.SERVER)
    if upgrade_request:
        ws.initiate_upgrade_connection(
            headers=upgrade_request.headers, path=upgrade_request.target
        )
    transport = cls(stream, ws)

    # If the timeout fires first, this placeholder is what gets reported.
    event: Union[str, Event] = "Websocket handshake timeout"
    with trio.move_on_after(WEBSOCKET_HANDSHAKE_TIMEOUT):
        event = await transport._next_ws_event()

    if not isinstance(event, Request):
        transport.logger.warning("Unexpected event during WebSocket handshake", ws_event=event)
        raise TransportError(f"Unexpected event during WebSocket handshake: {event}")
    transport.logger.debug("Accepting WebSocket upgrade")
    await transport._net_send(AcceptConnection())
    return transport
class MockWebsocketConnection:
    """Test client talking to a WebsocketServer over an in-memory trio stream pair."""

    def __init__(self, path: str, *, framework: ASGIFramework = echo_framework) -> None:
        self.client_stream, server_stream = trio.testing.memory_stream_pair()
        server_stream.socket = MockSocket()
        self.server = WebsocketServer(framework, Config(), server_stream)
        self.connection = WSConnection(ConnectionType.CLIENT)
        # Deliver the opening handshake straight to the server's wsproto state.
        opening = self.connection.send(Request(target=path, host="hypercorn"))
        self.server.connection.receive_data(opening)

    async def send(self, data: AnyStr) -> None:
        """Send ``data`` as one websocket message, then yield once."""
        frame = self.connection.send(Message(data=data))
        await self.client_stream.send_all(frame)
        await trio.sleep(0)  # Allow the server to respond

    async def receive(self) -> List[Message]:
        """Read up to 64 KiB from the stream and return the resulting events."""
        incoming = await self.client_stream.receive_some(2**16)
        self.connection.receive_data(incoming)
        return list(self.connection.events())

    async def close(self) -> None:
        """Send a close frame with code 1000."""
        closing = self.connection.send(CloseConnection(code=1000))
        await self.client_stream.send_all(closing)
class MockWebsocketConnection:
    """Test client talking to a WebsocketServer through a MockTransport."""

    def __init__(
        self,
        path: str,
        event_loop: asyncio.AbstractEventLoop,
        *,
        framework: Type[ASGIFramework] = EchoFramework,
    ) -> None:
        self.transport = MockTransport()
        self.server = WebsocketServer(  # type: ignore
            framework, event_loop, Config(), self.transport)
        self.connection = WSConnection(ConnectionType.CLIENT)
        opening = self.connection.send(Request(target=path, host="hypercorn"))
        self.server.data_received(opening)

    async def send(self, data: AnyStr) -> None:
        """Deliver ``data`` to the server as one message, then yield once."""
        frame = self.connection.send(Message(data=data))
        self.server.data_received(frame)
        await asyncio.sleep(0)  # Allow the server to respond

    async def receive(self) -> List[Message]:
        """Wait for transport output, feed it to the client, return its events."""
        await self.transport.updated.wait()
        self.connection.receive_data(self.transport.data)
        self.transport.clear()
        return list(self.connection.events())

    def close(self) -> None:
        """Deliver a close frame with code 1000 to the server."""
        closing = self.connection.send(CloseConnection(code=1000))
        self.server.data_received(closing)
def _make_handshake(
    response_status,
    response_headers,
    subprotocols=None,
    extensions=None,
    auto_accept_key=True,
):
    """Drive a client WSConnection against a handcrafted h11 server response.

    :param response_status: status code of the InformationalResponse sent back.
    :param response_headers: headers for that response. The caller's sequence
        is copied, never modified.
    :param subprotocols: subprotocols the client requests.
    :param extensions: extensions the client requests.
    :param auto_accept_key: if true, a correct ``Sec-WebSocket-Accept`` header
        computed from the client's key is appended.
    :returns: every event the client produced.
    """
    client = WSConnection(CLIENT)
    server = h11.Connection(h11.SERVER)
    server.receive_data(
        client.send(
            Request(
                host="localhost",
                target="/",
                subprotocols=subprotocols or [],
                extensions=extensions or [],
            )
        )
    )
    request = server.next_event()
    # Work on a copy so the caller's header list is not mutated.
    response_headers = list(response_headers)
    if auto_accept_key:
        full_request_headers = normed_header_dict(request.headers)
        response_headers.append(
            (
                b"Sec-WebSocket-Accept",
                generate_accept_token(full_request_headers[b"sec-websocket-key"]),
            )
        )
    response = h11.InformationalResponse(
        status_code=response_status, headers=response_headers
    )
    client.receive_data(server.send(response))
    return list(client.events())
def __init__(
    self,
    app: ASGIFramework,
    config: Config,
    stream: trio.abc.Stream,
    *,
    upgrade_request: Optional[h11.Request] = None,
) -> None:
    """Initialise a wsproto-backed websocket server on a trio stream.

    Sets up a memory channel (buffer size 10) for the app, a send lock, and a
    message buffer bounded by ``config.websocket_max_message_size``. If
    ``upgrade_request`` is given, its headers and target start the wsproto
    upgrade.
    """
    super().__init__(stream, "wsproto")
    self.app = app
    self.config = config
    self.connection = WSConnection(ConnectionType.SERVER)
    self.response: Optional[dict] = None
    self.scope: Optional[dict] = None
    self.send_lock = trio.Lock()
    self.state = ASGIWebsocketState.HANDSHAKE
    self.buffer = WebsocketBuffer(self.config.websocket_max_message_size)
    self.app_send_channel, self.app_receive_channel = trio.open_memory_channel(10)

    if upgrade_request is None:
        return
    self.connection.initiate_upgrade_connection(
        upgrade_request.headers, upgrade_request.target)
async def init_for_server(cls, stream, first_request_data=None):
    """Create a server transport and accept the client's WebSocket handshake.

    ``first_request_data``, if non-empty, is fed to wsproto before waiting.
    The wait for the client's Request lasts at most
    ``WEBSOCKET_HANDSHAKE_TIMEOUT``.

    Raises:
        TransportError: on a wsproto RemoteProtocolError, on timeout, or on
            any event other than a Request.
    """
    ws = WSConnection(ConnectionType.SERVER)
    try:
        if first_request_data:
            ws.receive_data(first_request_data)
        transport = cls(stream, ws)
        # If the timeout fires first, this placeholder is what gets reported.
        event = "Websocket handshake timeout"
        with trio.move_on_after(WEBSOCKET_HANDSHAKE_TIMEOUT):
            event = await transport._next_ws_event()
    except RemoteProtocolError as exc:
        raise TransportError(f"Invalid WebSocket query: {exc}") from exc

    if not isinstance(event, Request):
        transport.logger.warning("Unexpected event during WebSocket handshake", ws_event=event)
        raise TransportError(f"Unexpected event during WebSocket handshake: {event}")
    transport.logger.debug("Accepting WebSocket upgrade")
    await transport._net_send(AcceptConnection())
    return transport