def test_round_trip(self):
    """A key produced by build_request survives a full request/response check."""
    # Client builds the handshake request; server validates it and echoes the key.
    req_headers = Headers()
    client_key = build_request(req_headers)
    server_key = check_request(req_headers)
    self.assertEqual(client_key, server_key)

    # Server answers with the key it extracted; client checks it against its own.
    resp_headers = Headers()
    build_response(resp_headers, server_key)
    check_response(resp_headers, client_key)
def test_accept_response(self):
    """A 101 handshake reply is parsed into a single bodiless Response event."""
    raw_reply = (
        f"HTTP/1.1 101 Switching Protocols\r\n"
        f"Upgrade: websocket\r\n"
        f"Connection: Upgrade\r\n"
        f"Sec-WebSocket-Accept: {ACCEPT}\r\n"
        f"Date: {DATE}\r\n"
        f"Server: {USER_AGENT}\r\n"
        f"\r\n"
    ).encode()
    # Pin the handshake key so the expected Sec-WebSocket-Accept matches.
    with unittest.mock.patch("websockets.client.generate_key", return_value=KEY):
        client = ClientConnection("ws://example.com/test")
        client.connect()
        client.receive_data(raw_reply)

    [response] = client.events_received()

    expected_headers = Headers(
        {
            "Upgrade": "websocket",
            "Connection": "Upgrade",
            "Sec-WebSocket-Accept": ACCEPT,
            "Date": DATE,
            "Server": USER_AGENT,
        }
    )
    self.assertEqual(response.status_code, 101)
    self.assertEqual(response.reason_phrase, "Switching Protocols")
    self.assertEqual(response.headers, expected_headers)
    self.assertIsNone(response.body)
def __init__(self, uri: str, health_check_uri: str, cert, token):
    """Store connection settings and start a background event loop.

    A task wrapping ``self._connect()`` is scheduled on a fresh event loop,
    and that loop is then run forever on a new thread. This constructor
    does not wait for the connection to be established.

    Args:
        uri: websocket URI; a ``wss`` prefix enables TLS when no cert is given.
        health_check_uri: stored as ``self._hc_uri``.
        cert: CA data passed as ``cadata`` to ``load_verify_locations``, or None.
        token: value of the ``token`` header, or None to omit the header.
    """
    self._uri = uri
    self._hc_uri = health_check_uri
    self._token = token

    self._extra_headers = Headers()
    if token is not None:
        self._extra_headers["token"] = token

    # Same semantics as the ``ssl`` argument of a websocket connect call:
    # None lets the URL scheme decide, True forces TLS, and an SSLContext
    # loaded with the given certificate allows self-signed setups.
    self._cert = cert
    tls: Optional[Union[bool, ssl.SSLContext]]
    if cert is None:
        tls = True if uri.startswith("wss") else None
    else:
        tls = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        tls.load_verify_locations(cadata=cert)
    self._ssl_context: Optional[Union[bool, ssl.SSLContext]] = tls

    loop = asyncio.new_event_loop()
    self._loop = loop
    self._connection: asyncio.Task = loop.create_task(self._connect())
    self._ws: Optional[WebSocketClientProtocol] = None
    self._loop_thread = threading.Thread(target=loop.run_forever)
    self._loop_thread.start()
def __init__(self, host, port, protocol="wss", cert=None, token=None):
    """Derive the client/result URIs and TLS settings for ``host:port``.

    Args:
        host: server host name used to build the base URI.
        port: server port used to build the base URI.
        protocol: URI scheme; ``"wss"`` enables TLS when no cert is given.
        cert: CA data passed as ``cadata`` to ``load_verify_locations``, or None.
        token: value of the ``token`` header, or None to omit the header.
    """
    base = f"{protocol}://{host}:{port}"
    self._base_uri = base
    self._client_uri = f"{base}/client"
    self._result_uri = f"{base}/result"

    self._token = token
    self._extra_headers = Headers()
    if token is not None:
        self._extra_headers["token"] = token

    # Same semantics as the ``ssl`` argument of a websocket connect call:
    # None lets the scheme decide, True forces TLS, and an SSLContext
    # loaded with the given certificate allows self-signed setups.
    self._cert = cert
    if cert is None:
        self._ssl_context = True if protocol == "wss" else None
    else:
        context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        context.load_verify_locations(cadata=cert)
        self._ssl_context = context

    # Populated later; nothing is connected yet.
    self._loop = None
    self._incoming = None
    self._receive_future = None
    # First 8 hex digits of a uuid1, identical to str(uuid1()).split("-")[0].
    self._id = uuid.uuid1().hex[:8]
def test_reject_response(self):
    """A non-101 reply is parsed into a Response event that keeps its body."""
    raw_reply = (
        f"HTTP/1.1 404 Not Found\r\n"
        f"Date: {DATE}\r\n"
        f"Server: {USER_AGENT}\r\n"
        f"Content-Length: 13\r\n"
        f"Content-Type: text/plain; charset=utf-8\r\n"
        f"Connection: close\r\n"
        f"\r\n"
        f"Sorry folks.\n"
    ).encode()
    with unittest.mock.patch("websockets.client.generate_key", return_value=KEY):
        client = ClientConnection("ws://example.com/test")
        client.connect()
        client.receive_data(raw_reply)

    [response] = client.events_received()

    expected_headers = Headers(
        {
            "Date": DATE,
            "Server": USER_AGENT,
            "Content-Length": "13",
            "Content-Type": "text/plain; charset=utf-8",
            "Connection": "close",
        }
    )
    self.assertEqual(response.status_code, 404)
    self.assertEqual(response.reason_phrase, "Not Found")
    self.assertEqual(response.headers, expected_headers)
    self.assertEqual(response.body, b"Sorry folks.\n")
def __init__(
    self, url, token=None, cert=None, max_retries=10, timeout_multiplier=5
):
    """Store connection settings and create (but not run) an event loop.

    Args:
        url: websocket URL; must not be None. A ``wss`` prefix enables TLS
            when no cert is given.
        token: value of the ``token`` header, or None to omit the header.
        cert: CA data passed as ``cadata`` to ``load_verify_locations``, or None.
        max_retries: stored as ``self._max_retries``.
        timeout_multiplier: stored as ``self._timeout_multiplier``.

    Raises:
        ValueError: if ``url`` is None.
    """
    if url is None:
        raise ValueError("url was None")
    self.url = url
    self.token = token

    self._extra_headers = Headers()
    if token is not None:
        self._extra_headers["token"] = token

    # Same semantics as the ``ssl`` argument of a websocket connect call:
    # None lets the URL scheme decide, True forces TLS, and an SSLContext
    # loaded with the given certificate allows self-signed setups.
    if cert is None:
        self._ssl_context = True if url.startswith("wss") else None
    else:
        context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        context.load_verify_locations(cadata=cert)
        self._ssl_context = context

    self._max_retries = max_retries
    self._timeout_multiplier = timeout_multiplier
    self.websocket = None
    self.loop = asyncio.new_event_loop()
def test_serialize(self):
    """Serializing the RFC 6455 overview request yields the exact wire bytes."""
    # Example from the protocol overview in RFC 6455
    header_pairs = [
        ("Host", "server.example.com"),
        ("Upgrade", "websocket"),
        ("Connection", "Upgrade"),
        ("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ=="),
        ("Origin", "http://example.com"),
        ("Sec-WebSocket-Protocol", "chat, superchat"),
        ("Sec-WebSocket-Version", "13"),
    ]
    expected_wire = (
        b"GET /chat HTTP/1.1\r\n"
        b"Host: server.example.com\r\n"
        b"Upgrade: websocket\r\n"
        b"Connection: Upgrade\r\n"
        b"Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n"
        b"Origin: http://example.com\r\n"
        b"Sec-WebSocket-Protocol: chat, superchat\r\n"
        b"Sec-WebSocket-Version: 13\r\n"
        b"\r\n"
    )
    request = Request("/chat", Headers(header_pairs))
    self.assertEqual(request.serialize(), expected_wire)
def test_extra_headers(self):
    """Every supported form of ``extra_headers`` ends up in the 101 response."""
    variants = [
        Headers({"X-Spam": "Eggs"}),
        {"X-Spam": "Eggs"},
        [("X-Spam", "Eggs")],
        lambda path, headers: Headers({"X-Spam": "Eggs"}),
        lambda path, headers: {"X-Spam": "Eggs"},
        lambda path, headers: [("X-Spam", "Eggs")],
    ]
    for extra_headers in variants:
        with self.subTest(extra_headers=extra_headers):
            server = ServerConnection(extra_headers=extra_headers)
            response = server.accept(self.make_request())
            self.assertEqual(response.status_code, 101)
            self.assertEqual(response.headers["X-Spam"], "Eggs")
def test_connect_request(self):
    """An opening handshake received by the server becomes a Request event."""
    raw_request = (
        f"GET /test HTTP/1.1\r\n"
        f"Host: example.com\r\n"
        f"Upgrade: websocket\r\n"
        f"Connection: Upgrade\r\n"
        f"Sec-WebSocket-Key: {KEY}\r\n"
        f"Sec-WebSocket-Version: 13\r\n"
        f"User-Agent: {USER_AGENT}\r\n"
        f"\r\n"
    ).encode()
    server = ServerConnection()
    server.receive_data(raw_request)

    [request] = server.events_received()

    expected_headers = Headers(
        {
            "Host": "example.com",
            "Upgrade": "websocket",
            "Connection": "Upgrade",
            "Sec-WebSocket-Key": KEY,
            "Sec-WebSocket-Version": "13",
            "User-Agent": USER_AGENT,
        }
    )
    self.assertEqual(request.path, "/test")
    self.assertEqual(request.headers, expected_headers)
async def execute_queue_async(  # pylint: disable=too-many-arguments
    self,
    ws_uri: str,
    ee_id: str,
    pool_sema: threading.BoundedSemaphore,
    evaluators: Callable[..., Any],
    cert: Optional[Union[str, bytes]] = None,
    token: Optional[str] = None,
) -> None:
    """Drive the job queue until it is inactive, publishing state changes.

    Publishes an initial snapshot. Then, about once per second, it launches
    jobs, runs each evaluator callable and publishes the changes since the
    last transition. When the queue becomes inactive it asserts completion
    and publishes a final snapshot.

    Despite its annotation, ``evaluators`` is iterated, so it is expected to
    be an iterable of zero-argument callables, or None.

    Raises:
        asyncio.CancelledError: on external cancellation, or when
            ``self.stopped`` becomes true; jobs are stopped first.
        Exception: any other error, re-raised after jobs are stopped.
    """
    if evaluators is None:
        evaluators = []

    if cert is None:
        tls_context = True if ws_uri.startswith("wss") else None
    else:
        tls_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        tls_context.load_verify_locations(cadata=cert)

    auth_headers = Headers()
    if token is not None:
        auth_headers["token"] = token

    async def publish(changes):
        # Every publish goes to the same endpoint with the same TLS/auth setup.
        await JobQueue._publish_changes(
            ee_id, changes, ws_uri, tls_context, auth_headers
        )

    try:
        await publish(self._differ.snapshot())
        while True:
            self.launch_jobs(pool_sema)
            await asyncio.sleep(1)
            for evaluate in evaluators:
                evaluate()
            await publish(self.changes_after_transition())
            if self.stopped:
                # Reuse the cancellation path below to stop jobs cleanly.
                raise asyncio.CancelledError
            if not self.is_active():
                break
    except asyncio.CancelledError:
        logger.debug("queue cancelled, stopping jobs...")
        await self.stop_jobs_async()
        logger.debug("jobs stopped, re-raising CancelledError")
        raise
    except Exception:
        # logger.exception already records the active traceback.
        logger.exception("unexpected exception in queue")
        await self.stop_jobs_async()
        logger.debug("jobs stopped, re-raising exception")
        raise

    self.assert_complete()
    self._differ.transition(self.job_list)
    await publish(self._differ.snapshot())
def assertValidResponseHeaders(self, key="CSIRmL8dWYxeAdr/XpEHRw=="):
    """
    Yield freshly built response headers for the caller to modify, then
    check that they still pass ``check_response`` for ``key``.
    """
    response_headers = Headers()
    build_response(response_headers, key)
    yield response_headers
    check_response(response_headers, key)
def assertValidRequestHeaders(self):
    """
    Yield freshly built request headers for the caller to modify, then
    check that they still pass ``check_request``.
    """
    request_headers = Headers()
    build_request(request_headers)
    yield request_headers
    check_request(request_headers)
async def send_loop(q):
    """Forward messages from queue ``q`` over a websocket to ``uri``.

    Messages are sent in order until the sentinel string ``"stop"`` is
    dequeued, at which point the function returns and the connection is
    closed by the ``async with`` block.

    TLS and auth follow the same rules as the other connection helpers:
    - A CA certificate is only loaded when ``self.cert`` is set. Previously
      ``cadata=None`` raised TypeError.
    - Without a certificate, TLS is inferred from the ``wss`` scheme.
    - The ``token`` header is only sent when ``self.token`` is set.
      Previously a None token was put into the headers.
    """
    if self.cert is not None:
        ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        ssl_context.load_verify_locations(cadata=self.cert)
    else:
        ssl_context = True if uri.startswith("wss") else None

    headers = Headers()
    if self.token is not None:
        headers["token"] = self.token

    async with websockets.connect(
        uri, ssl=ssl_context, extra_headers=headers
    ) as websocket:
        while True:
            msg = await q.get()
            if msg == "stop":
                return
            await websocket.send(msg)
async def execute_queue_async(self, ws_uri, ee_id, pool_sema, evaluators, cert=None, token=None):
    """Drive the job queue over one websocket connection until it is inactive.

    Opens a websocket to ``ws_uri`` and publishes an initial snapshot. Then,
    about once per second, it launches jobs, runs ``evaluators`` (if not
    None) and publishes the changes since the last transition. When the
    queue is inactive or stopped, it stops jobs (if stopped), asserts
    completion and publishes a final snapshot on the same connection.

    Args:
        ws_uri: websocket URI; a ``wss`` prefix enables TLS when no cert is given.
        ee_id: identifier forwarded to ``JobQueue._publish_changes``.
        pool_sema: passed to ``self.launch_jobs``.
        evaluators: iterable of zero-argument callables, or None.
        cert: CA data passed as ``cadata`` to ``load_verify_locations``, or None.
        token: value of the ``token`` header, or None to omit it.

    Raises:
        asyncio.CancelledError: re-raised after optional job cleanup.
    """
    # Same semantics as the websockets ``ssl`` argument: an explicit CA
    # context when a cert is given, otherwise decide from the URI scheme.
    if cert is not None:
        ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        ssl_context.load_verify_locations(cadata=cert)
    else:
        ssl_context = True if ws_uri.startswith("wss") else None
    headers = Headers()
    if token is not None:
        headers["token"] = token
    async with websockets.connect(ws_uri, ssl=ssl_context, extra_headers=headers) as websocket:
        await JobQueue._publish_changes(ee_id, self._differ.snapshot(), websocket)
        try:
            while True:
                self.launch_jobs(pool_sema)
                await asyncio.sleep(1)
                if evaluators is not None:
                    for func in evaluators:
                        func()
                await JobQueue._publish_changes(
                    ee_id, self.changes_after_transition(), websocket)
                if not self.is_active() or self.stopped:
                    break
        except asyncio.CancelledError:
            if self.stopped:
                logger.debug(
                    "observed that the queue had stopped after cancellation, stopping jobs..."
                )
                # NOTE(review): this path calls the synchronous stop_jobs(),
                # while the normal stopped path below awaits
                # stop_jobs_async(); confirm a blocking call is intended here.
                self.stop_jobs()
                logger.debug("jobs now stopped (after cancellation)")
            raise
        # Normal exit from the loop: the queue is either inactive or stopped.
        if self.stopped:
            logger.debug(
                "observed that the queue had stopped, stopping jobs...")
            await self.stop_jobs_async()
            logger.debug("jobs now stopped")
        self.assert_complete()
        self._differ.transition(self.job_list)
        # Final snapshot is sent on the still-open connection.
        await JobQueue._publish_changes(ee_id, self._differ.snapshot(), websocket)
def make_request(self):
    """Build a valid client opening-handshake Request for ``/test``."""
    handshake_fields = {
        "Host": "example.com",
        "Upgrade": "websocket",
        "Connection": "Upgrade",
        "Sec-WebSocket-Key": KEY,
        "Sec-WebSocket-Version": "13",
        "User-Agent": USER_AGENT,
    }
    return Request(path="/test", headers=Headers(handshake_fields))
def assertInvalidResponseHeaders(self, exc_type, key="CSIRmL8dWYxeAdr/XpEHRw=="):
    """
    Yield freshly built response headers for the caller to modify, then
    check that ``check_response`` rejects them with ``exc_type``, which
    must be an ``InvalidHandshake`` subclass.
    """
    response_headers = Headers()
    build_response(response_headers, key)
    yield response_headers
    assert issubclass(exc_type, InvalidHandshake)
    with self.assertRaises(exc_type):
        check_response(response_headers, key)
def test_extra_headers(self):
    """Each supported ``extra_headers`` form is copied into the handshake."""
    variants = (
        Headers({"X-Spam": "Eggs"}),
        {"X-Spam": "Eggs"},
        [("X-Spam", "Eggs")],
    )
    for extra_headers in variants:
        with self.subTest(extra_headers=extra_headers):
            request = ClientConnection(
                "wss://example.com/", extra_headers=extra_headers
            ).connect()
            self.assertEqual(request.headers["X-Spam"], "Eggs")
def assertInvalidRequestHeaders(self, exc_type):
    """
    Yield freshly built request headers for the caller to modify, then
    check that ``check_request`` rejects them with ``exc_type``, which
    must be an ``InvalidHandshake`` subclass.
    """
    request_headers = Headers()
    build_request(request_headers)
    yield request_headers
    assert issubclass(exc_type, InvalidHandshake)
    with self.assertRaises(exc_type):
        check_request(request_headers)
def make_accept_response(self, client):
    """Start ``client``'s handshake and build the matching 101 reply."""
    request = client.connect()
    # The accept value is derived from the key the client just sent.
    accept_value = accept_key(request.headers["Sec-WebSocket-Key"])
    reply_headers = Headers({
        "Upgrade": "websocket",
        "Connection": "Upgrade",
        "Sec-WebSocket-Accept": accept_value,
    })
    return Response(
        status_code=101,
        reason_phrase="Switching Protocols",
        headers=reply_headers,
    )
def test_serialize_with_body(self):
    """A Response with a body serializes headers, blank line, then the body."""
    header_pairs = [("Content-Length", "13"), ("Content-Type", "text/plain")]
    payload = b"Hello world!\n"
    expected_wire = (
        b"HTTP/1.1 200 OK\r\n"
        b"Content-Length: 13\r\n"
        b"Content-Type: text/plain\r\n"
        b"\r\n"
        b"Hello world!\n"
    )
    response = Response(200, "OK", Headers(header_pairs), payload)
    self.assertEqual(response.serialize(), expected_wire)
def test_connect_request(self):
    """connect() produces the expected opening-handshake Request."""
    # Pin the handshake key so the Sec-WebSocket-Key header is predictable.
    with unittest.mock.patch("websockets.client.generate_key", return_value=KEY):
        request = ClientConnection("wss://example.com/test").connect()

    expected_headers = Headers({
        "Host": "example.com",
        "Upgrade": "websocket",
        "Connection": "Upgrade",
        "Sec-WebSocket-Key": KEY,
        "Sec-WebSocket-Version": "13",
        "User-Agent": USER_AGENT,
    })
    self.assertEqual(request.path, "/test")
    self.assertEqual(request.headers, expected_headers)
def __init__(
    self,
    uri: str,
    health_check_uri: str,
    cert: Optional[Union[str, bytes]],
    token: Optional[str],
) -> None:
    """Start a background event loop and block until ``_connect()`` finishes.

    A task wrapping ``self._connect()`` is scheduled on a fresh event loop,
    and the loop is run forever on a new thread. The constructor polls until
    that task is done. If the task raised, ``self.stop()`` is called and the
    exception is re-raised, so a returned instance has a completed connection
    attempt.

    Args:
        uri: websocket URI; a ``wss`` prefix enables TLS when no cert is given.
        health_check_uri: stored as ``self._hc_uri``.
        cert: CA certificate data passed as ``cadata`` to
            ``SSLContext.load_verify_locations``. It must be PEM text (str) or
            DER bytes; a filesystem path is not accepted by ``cadata``. Pass
            None to infer TLS from the URI scheme.
        token: value of the ``token`` header, or None to omit the header.

    Raises:
        Exception: whatever ``self._connect()`` raised, after ``self.stop()``.
    """
    self._uri = uri
    self._hc_uri = health_check_uri
    self._token = token
    self._extra_headers = Headers()
    if token is not None:
        self._extra_headers["token"] = token

    # Mimics the behavior of the ssl argument when connecting to
    # websockets. If none is specified it will deduce based on the url,
    # if True it will enforce TLS, and if you want to use self signed
    # certificates you need to pass an ssl_context with the certificate
    # loaded.
    self._cert = cert
    ssl_context: Optional[Union[bool, ssl.SSLContext]]
    if cert is not None:
        ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        ssl_context.load_verify_locations(cadata=cert)
    else:
        ssl_context = True if uri.startswith("wss") else None
    self._ssl_context: Optional[Union[bool, ssl.SSLContext]] = ssl_context

    self._loop = asyncio.new_event_loop()
    self._connection: asyncio.Task = self._loop.create_task(self._connect())
    self._ws: Optional[WebSocketClientProtocol] = None
    self._loop_thread = threading.Thread(target=self._loop.run_forever)
    self._loop_thread.start()

    # Ensure the async thread either makes a connection, or raises the
    # _connect() exception before returning. Only after a connection has
    # been made can this class be used safely.
    while not self._connection.done():
        time.sleep(0.1)
    try:
        self._connection.result()
    except Exception:
        self.stop()
        raise
def test_accept_response(self):
    """accept() on a valid request yields a 101 Response without a body."""
    server = ServerConnection()
    # Pin the Date header so the expected headers are deterministic.
    with unittest.mock.patch("email.utils.formatdate", return_value=DATE):
        response = server.accept(self.make_request())

    expected_headers = Headers({
        "Upgrade": "websocket",
        "Connection": "Upgrade",
        "Sec-WebSocket-Accept": ACCEPT,
        "Date": DATE,
        "Server": USER_AGENT,
    })
    self.assertIsInstance(response, Response)
    self.assertEqual(response.status_code, 101)
    self.assertEqual(response.reason_phrase, "Switching Protocols")
    self.assertEqual(response.headers, expected_headers)
    self.assertIsNone(response.body)
def test_reject_response(self):
    """reject() builds a plain-text error Response that closes the connection."""
    server = ServerConnection()
    with unittest.mock.patch("email.utils.formatdate", return_value=DATE):
        response = server.reject(http.HTTPStatus.NOT_FOUND, "Sorry folks.\n")

    expected_headers = Headers({
        "Date": DATE,
        "Server": USER_AGENT,
        "Content-Length": "13",
        "Content-Type": "text/plain; charset=utf-8",
        "Connection": "close",
    })
    self.assertIsInstance(response, Response)
    self.assertEqual(response.status_code, 404)
    self.assertEqual(response.reason_phrase, "Not Found")
    self.assertEqual(response.headers, expected_headers)
    self.assertEqual(response.body, b"Sorry folks.\n")
def test_serialize(self):
    """Serializing the RFC 6455 overview response yields the exact wire bytes."""
    # Example from the protocol overview in RFC 6455
    header_pairs = [
        ("Upgrade", "websocket"),
        ("Connection", "Upgrade"),
        ("Sec-WebSocket-Accept", "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="),
        ("Sec-WebSocket-Protocol", "chat"),
    ]
    expected_wire = (
        b"HTTP/1.1 101 Switching Protocols\r\n"
        b"Upgrade: websocket\r\n"
        b"Connection: Upgrade\r\n"
        b"Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n"
        b"Sec-WebSocket-Protocol: chat\r\n"
        b"\r\n"
    )
    response = Response(101, "Switching Protocols", Headers(header_pairs))
    self.assertEqual(response.serialize(), expected_wire)
def _mock_websocket_connect_exception(*args, **kwargs):
    """Ignore all arguments and raise ``InvalidStatusCode`` for HTTP 404.

    Judging by its name, this presumably replaces a websocket connect call
    in tests; confirm at the call site.
    """
    rejection = InvalidStatusCode(404, Headers())
    raise rejection
def test_str(self):
    """str() of each exception type renders the documented message."""
    # Each entry pairs an exception instance with its expected str().
    # fmt: off
    cases = [
        (WebSocketException("something went wrong"),
         "something went wrong"),
        (ConnectionClosed(1000, ""),
         "code = 1000 (OK), no reason"),
        (ConnectionClosed(1006, None),
         "code = 1006 (connection closed abnormally [internal]), no reason"),
        (ConnectionClosed(3000, None),
         "code = 3000 (registered), no reason"),
        (ConnectionClosed(4000, None),
         "code = 4000 (private use), no reason"),
        (ConnectionClosedError(1016, None),
         "code = 1016 (unknown), no reason"),
        (ConnectionClosedOK(1001, "bye"),
         "code = 1001 (going away), reason = bye"),
        (InvalidHandshake("invalid request"),
         "invalid request"),
        (SecurityError("redirect from WSS to WS"),
         "redirect from WSS to WS"),
        (InvalidMessage("malformed HTTP message"),
         "malformed HTTP message"),
        (InvalidHeader("Name"),
         "missing Name header"),
        (InvalidHeader("Name", None),
         "missing Name header"),
        (InvalidHeader("Name", ""),
         "empty Name header"),
        (InvalidHeader("Name", "Value"),
         "invalid Name header: Value"),
        (InvalidHeaderFormat("Sec-WebSocket-Protocol", "expected token", "a=|", 3),
         "invalid Sec-WebSocket-Protocol header: "
         "expected token at 3 in a=|"),
        (InvalidHeaderValue("Sec-WebSocket-Version", "42"),
         "invalid Sec-WebSocket-Version header: 42"),
        (InvalidOrigin("http://bad.origin"),
         "invalid Origin header: http://bad.origin"),
        (InvalidUpgrade("Upgrade"),
         "missing Upgrade header"),
        (InvalidUpgrade("Connection", "websocket"),
         "invalid Connection header: websocket"),
        (InvalidStatusCode(403),
         "server rejected WebSocket connection: HTTP 403"),
        (NegotiationError("unsupported subprotocol: spam"),
         "unsupported subprotocol: spam"),
        (DuplicateParameter("a"),
         "duplicate parameter: a"),
        (InvalidParameterName("|"),
         "invalid parameter name: |"),
        (InvalidParameterValue("a", None),
         "missing value for parameter a"),
        (InvalidParameterValue("a", ""),
         "empty value for parameter a"),
        (InvalidParameterValue("a", "|"),
         "invalid value for parameter a: |"),
        (AbortHandshake(200, Headers(), b"OK\n"),
         "HTTP 200, 0 headers, 3 bytes"),
        (RedirectHandshake("wss://example.com"),
         "redirect to wss://example.com"),
        (InvalidState("WebSocket connection isn't established yet"),
         "WebSocket connection isn't established yet"),
        (InvalidURI("|"),
         "| isn't a valid URI"),
        (PayloadTooBig("payload length exceeds limit: 2 > 1 bytes"),
         "payload length exceeds limit: 2 > 1 bytes"),
        (ProtocolError("invalid opcode: 7"),
         "invalid opcode: 7"),
    ]
    # fmt: on
    for exc, expected in cases:
        with self.subTest(exception=exc):
            self.assertEqual(str(exc), expected)
def test_str(self):
    """str() of each exception type renders the documented message."""
    # Each entry pairs an exception instance with its expected str().
    # fmt: off
    cases = [
        (WebSocketException("something went wrong"),
         "something went wrong"),
        (ConnectionClosed(Close(1000, ""), Close(1000, ""), True),
         "received 1000 (OK); then sent 1000 (OK)"),
        (ConnectionClosed(Close(1001, "Bye!"), Close(1001, "Bye!"), False),
         "sent 1001 (going away) Bye!; then received 1001 (going away) Bye!"),
        (ConnectionClosed(Close(1000, "race"), Close(1000, "cond"), True),
         "received 1000 (OK) race; then sent 1000 (OK) cond"),
        (ConnectionClosed(Close(1000, "cond"), Close(1000, "race"), False),
         "sent 1000 (OK) race; then received 1000 (OK) cond"),
        (ConnectionClosed(None, Close(1009, ""), None),
         "sent 1009 (message too big); no close frame received"),
        (ConnectionClosed(Close(1002, ""), None, None),
         "received 1002 (protocol error); no close frame sent"),
        (ConnectionClosedOK(Close(1000, ""), Close(1000, ""), True),
         "received 1000 (OK); then sent 1000 (OK)"),
        (ConnectionClosedError(None, None, None),
         "no close frame received or sent"),
        (InvalidHandshake("invalid request"),
         "invalid request"),
        (SecurityError("redirect from WSS to WS"),
         "redirect from WSS to WS"),
        (InvalidMessage("malformed HTTP message"),
         "malformed HTTP message"),
        (InvalidHeader("Name"),
         "missing Name header"),
        (InvalidHeader("Name", None),
         "missing Name header"),
        (InvalidHeader("Name", ""),
         "empty Name header"),
        (InvalidHeader("Name", "Value"),
         "invalid Name header: Value"),
        (InvalidHeaderFormat(
            "Sec-WebSocket-Protocol", "expected token", "a=|", 3
         ),
         "invalid Sec-WebSocket-Protocol header: "
         "expected token at 3 in a=|"),
        (InvalidHeaderValue("Sec-WebSocket-Version", "42"),
         "invalid Sec-WebSocket-Version header: 42"),
        (InvalidOrigin("http://bad.origin"),
         "invalid Origin header: http://bad.origin"),
        (InvalidUpgrade("Upgrade"),
         "missing Upgrade header"),
        (InvalidUpgrade("Connection", "websocket"),
         "invalid Connection header: websocket"),
        (InvalidStatus(Response(401, "Unauthorized", Headers())),
         "server rejected WebSocket connection: HTTP 401"),
        (InvalidStatusCode(403, Headers()),
         "server rejected WebSocket connection: HTTP 403"),
        (NegotiationError("unsupported subprotocol: spam"),
         "unsupported subprotocol: spam"),
        (DuplicateParameter("a"),
         "duplicate parameter: a"),
        (InvalidParameterName("|"),
         "invalid parameter name: |"),
        (InvalidParameterValue("a", None),
         "missing value for parameter a"),
        (InvalidParameterValue("a", ""),
         "empty value for parameter a"),
        (InvalidParameterValue("a", "|"),
         "invalid value for parameter a: |"),
        (AbortHandshake(200, Headers(), b"OK\n"),
         "HTTP 200, 0 headers, 3 bytes"),
        (RedirectHandshake("wss://example.com"),
         "redirect to wss://example.com"),
        (InvalidState("WebSocket connection isn't established yet"),
         "WebSocket connection isn't established yet"),
        (InvalidURI("|"),
         "| isn't a valid URI"),
        (PayloadTooBig("payload length exceeds limit: 2 > 1 bytes"),
         "payload length exceeds limit: 2 > 1 bytes"),
        (ProtocolError("invalid opcode: 7"),
         "invalid opcode: 7"),
    ]
    # fmt: on
    for exc, expected in cases:
        with self.subTest(exception=exc):
            self.assertEqual(str(exc), expected)