def test_pull_encrypted_extensions_with_alpn_and_early_data(self):
    """Parse EncryptedExtensions carrying ALPN and early_data, then re-encode it."""
    reader = Buffer(
        data=load("tls_encrypted_extensions_with_alpn_and_early_data.bin")
    )
    parsed = pull_encrypted_extensions(reader)
    self.assertIsNotNone(parsed)
    self.assertTrue(reader.eof())

    expected_others = [
        (tls.ExtensionType.SERVER_NAME, b""),
        (
            tls.ExtensionType.QUIC_TRANSPORT_PARAMETERS,
            SERVER_QUIC_TRANSPORT_PARAMETERS_3,
        ),
    ]
    expected = EncryptedExtensions(
        alpn_protocol="hq-20",
        early_data=True,
        other_extensions=expected_others,
    )
    self.assertEqual(parsed, expected)

    # serialize: the encoding is expected to end exactly at the buffer's
    # 116-byte capacity (only the size is checked, not the bytes)
    writer = Buffer(116)
    push_encrypted_extensions(writer, parsed)
    self.assertTrue(writer.eof())
def test_seek(self):
    """seek() moves the cursor; eof() only reports true at the end of the data."""
    buf = Buffer(data=b"01234567")
    self.assertFalse(buf.eof())
    self.assertEqual(buf.tell(), 0)

    for position, at_end in ((4, False), (8, True)):
        buf.seek(position)
        check_eof = self.assertTrue if at_end else self.assertFalse
        check_eof(buf.eof())
        self.assertEqual(buf.tell(), position)
def test_pull_certificate(self):
    """Parse a Certificate message with an empty request context and one entry."""
    reader = Buffer(data=load("tls_certificate.bin"))
    cert_msg = pull_certificate(reader)
    self.assertTrue(reader.eof())

    expected_chain = [(CERTIFICATE_DATA, b"")]
    self.assertEqual(cert_msg.request_context, b"")
    self.assertEqual(cert_msg.certificates, expected_chain)
def test_pull_new_session_ticket_with_unknown_extension(self):
    """A NewSessionTicket with an unrecognised extension parses and round-trips."""
    wire = load("tls_new_session_ticket_with_unknown_extension.bin")
    reader = Buffer(data=wire)
    ticket_msg = pull_new_session_ticket(reader)
    self.assertIsNotNone(ticket_msg)
    self.assertTrue(reader.eof())

    ticket_bytes = binascii.unhexlify(
        "dbe6f1a77a78c0426bfa607cd0d02b350247d90618704709596beda7e962cc81"
    )
    expected = NewSessionTicket(
        ticket_lifetime=86400,
        ticket_age_add=3303452425,
        ticket_nonce=b"",
        ticket=ticket_bytes,
        max_early_data_size=4294967295,
        other_extensions=[(12345, b"foo")],
    )
    self.assertEqual(ticket_msg, expected)

    # serialize: re-encoding must reproduce the captured bytes exactly
    writer = Buffer(100)
    push_new_session_ticket(writer, ticket_msg)
    self.assertEqual(writer.data, wire)
def test_pull_client_hello_with_psk(self):
    """Parse a ClientHello offering a PSK with early data, then round-trip it."""
    wire = load("tls_client_hello_with_psk.bin")
    reader = Buffer(data=wire)
    hello = pull_client_hello(reader)

    self.assertEqual(hello.early_data, True)

    psk_identity = binascii.unhexlify(
        "fab3dc7d79f35ea53e9adf21150e601591a750b80cde0cd167fef6e0cdbc032a"
        "c4161fc5c5b66679de49524bd5624c50d71ba3e650780a4bfe402d6a06a00525"
        "0b5dc52085233b69d0dd13924cc5c713a396784ecafc59f5ea73c1585d79621b"
        "8a94e4f2291b17427d5185abf4a994fca74ee7a7f993a950c71003fc7cf8"
    )
    psk_binder = binascii.unhexlify(
        "1788ad43fdff37cfc628f24b6ce7c8c76180705380da17da32811b5bae4e78"
        "d7aaaf65a9b713872f2bb28818ca1a6b01"
    )
    expected_psk = tls.OfferedPsks(
        identities=[(psk_identity, 2067156378)],
        binders=[psk_binder],
    )
    self.assertEqual(hello.pre_shared_key, expected_psk)
    self.assertTrue(reader.eof())

    # serialize: re-encoding must reproduce the captured bytes exactly
    writer = Buffer(1000)
    push_client_hello(writer, hello)
    self.assertEqual(writer.data, wire)
def test_pull_server_hello(self):
    """Parse a ServerHello without PSK and check each decoded field."""
    reader = Buffer(data=load("tls_server_hello.bin"))
    hello = pull_server_hello(reader)
    self.assertTrue(reader.eof())

    expected_fields = (
        (
            "random",
            binascii.unhexlify(
                "ada85271d19680c615ea7336519e3fdf6f1e26f3b1075ee1de96ffa8884e8280"
            ),
        ),
        (
            "session_id",
            binascii.unhexlify(
                "9aee82a2d186c1cb32a329d9dcfe004a1a438ad0485a53c6bfcf55c132a23235"
            ),
        ),
        ("cipher_suite", tls.CipherSuite.AES_256_GCM_SHA384),
        ("compression_method", tls.CompressionMethod.NULL),
        (
            "key_share",
            (
                tls.Group.SECP256R1,
                binascii.unhexlify(
                    "048b27d0282242d84b7fcc02a9c4f13eca0329e3c7029aa34a33794e6e7ba189"
                    "5cca1c503bf0378ac6937c354912116ff3251026bca1958d7f387316c83ae6cf"
                    "b2"
                ),
            ),
        ),
        ("pre_shared_key", None),
        ("supported_version", tls.TLS_VERSION_1_3),
    )
    for field_name, expected in expected_fields:
        self.assertEqual(getattr(hello, field_name), expected)
def test_pull_server_hello_with_unknown_extension(self):
    """A ServerHello with an unrecognised extension parses and round-trips."""
    wire = load("tls_server_hello_with_unknown_extension.bin")
    reader = Buffer(data=wire)
    hello = pull_server_hello(reader)
    self.assertTrue(reader.eof())

    expected_key_share = (
        tls.Group.SECP256R1,
        binascii.unhexlify(
            "048b27d0282242d84b7fcc02a9c4f13eca0329e3c7029aa34a33794e6e7ba189"
            "5cca1c503bf0378ac6937c354912116ff3251026bca1958d7f387316c83ae6cf"
            "b2"
        ),
    )
    expected = ServerHello(
        random=binascii.unhexlify(
            "ada85271d19680c615ea7336519e3fdf6f1e26f3b1075ee1de96ffa8884e8280"
        ),
        session_id=binascii.unhexlify(
            "9aee82a2d186c1cb32a329d9dcfe004a1a438ad0485a53c6bfcf55c132a23235"
        ),
        cipher_suite=tls.CipherSuite.AES_256_GCM_SHA384,
        compression_method=tls.CompressionMethod.NULL,
        key_share=expected_key_share,
        supported_version=tls.TLS_VERSION_1_3,
        other_extensions=[(12345, b"foo")],
    )
    self.assertEqual(hello, expected)

    # serialize: re-encoding must reproduce the captured bytes exactly
    writer = Buffer(1000)
    push_server_hello(writer, hello)
    self.assertEqual(writer.data, wire)
def test_pull_certificate_verify(self):
    """Parse a CertificateVerify message and check algorithm and signature."""
    reader = Buffer(data=load("tls_certificate_verify.bin"))
    message = pull_certificate_verify(reader)
    self.assertTrue(reader.eof())

    for actual, expected in (
        (message.algorithm, tls.SignatureAlgorithm.RSA_PSS_RSAE_SHA256),
        (message.signature, CERTIFICATE_VERIFY_SIGNATURE),
    ):
        self.assertEqual(actual, expected)
def parse_settings(data: bytes) -> Dict[int, int]:
    """Decode a settings payload into an ``{identifier: value}`` mapping.

    ``data`` is read as consecutive (identifier, value) pairs, each pulled with
    ``Buffer.pull_uint_var``. If an identifier appears more than once, the last
    value wins. A truncated pair propagates the buffer's read error.
    """
    buf = Buffer(data=data)
    settings: Dict[int, int] = {}
    while not buf.eof():
        identifier = buf.pull_uint_var()
        settings[identifier] = buf.pull_uint_var()
    return settings
def test_seek(self):
    """seek() accepts positions within the data and rejects anything outside it."""
    buf = Buffer(data=b"01234567")
    self.assertFalse(buf.eof())
    self.assertEqual(buf.tell(), 0)

    for position in (4, 8):
        buf.seek(position)
        check_eof = self.assertTrue if position == 8 else self.assertFalse
        check_eof(buf.eof())
        self.assertEqual(buf.tell(), position)

    # out-of-range targets raise and leave the cursor where it was (the end)
    for out_of_range in (-1, 9):
        with self.assertRaises(BufferReadError):
            buf.seek(out_of_range)
        self.assertEqual(buf.tell(), 8)
def test_pull_finished(self):
    """Parse a Finished message and check its verify_data."""
    expected_verify_data = binascii.unhexlify(
        "f157923234ff9a4921aadb2e0ec7b1a30fce73fb9ec0c4276f9af268f408ec68"
    )
    reader = Buffer(data=load("tls_finished.bin"))
    finished = pull_finished(reader)
    self.assertTrue(reader.eof())
    self.assertEqual(finished.verify_data, expected_verify_data)
def parse_settings(data: bytes) -> Dict[int, int]:
    """Decode a settings payload into an ``{identifier: value}`` mapping.

    ``data`` is read as consecutive (identifier, value) pairs, each pulled with
    ``Buffer.pull_uint_var``.

    :raises SettingsError: if an identifier is in ``RESERVED_SETTINGS`` or
        appears more than once.
    """
    buf = Buffer(data=data)

    def _pairs():
        # yield one (identifier, value) pair per iteration until the data ends
        while not buf.eof():
            yield buf.pull_uint_var(), buf.pull_uint_var()

    settings: Dict[int, int] = {}
    for identifier, value in _pairs():
        if identifier in RESERVED_SETTINGS:
            raise SettingsError("Setting identifier 0x%x is reserved" % identifier)
        if identifier in settings:
            raise SettingsError(
                "Setting identifier 0x%x is included twice" % identifier
            )
        settings[identifier] = value
    return settings
def test_encrypted_extensions(self):
    """Round-trip EncryptedExtensions that holds only QUIC transport parameters."""
    wire = load("tls_encrypted_extensions.bin")
    reader = Buffer(data=wire)
    parsed = pull_encrypted_extensions(reader)
    self.assertIsNotNone(parsed)
    self.assertTrue(reader.eof())

    transport_params = (
        tls.ExtensionType.QUIC_TRANSPORT_PARAMETERS,
        SERVER_QUIC_TRANSPORT_PARAMETERS,
    )
    self.assertEqual(
        parsed, EncryptedExtensions(other_extensions=[transport_params])
    )

    # serialize: re-encoding must reproduce the captured bytes exactly
    writer = Buffer(capacity=100)
    push_encrypted_extensions(writer, parsed)
    self.assertEqual(writer.data, wire)
def test_pull_version_negotiation(self):
    """Parse a Version Negotiation packet header and the version list after it."""
    buf = Buffer(data=load("version_negotiation.bin"))
    header = pull_quic_header(buf, host_cid_length=8)
    self.assertTrue(header.is_long_header)
    self.assertEqual(header.version, QuicProtocolVersion.NEGOTIATION)
    self.assertEqual(header.packet_type, None)
    self.assertEqual(header.destination_cid, binascii.unhexlify("9aac5a49ba87a849"))
    self.assertEqual(header.source_cid, binascii.unhexlify("f92f4336fa951ba1"))
    self.assertEqual(header.token, b"")
    self.assertEqual(header.integrity_tag, b"")
    self.assertEqual(header.rest_length, 8)
    self.assertEqual(buf.tell(), 23)

    # the rest of the packet is read as a sequence of 32-bit version numbers
    versions = []
    while not buf.eof():
        versions.append(buf.pull_uint32())
    # no trailing comma: it previously wrapped this call in a discarded tuple
    self.assertEqual(versions, [0x45474716, QuicProtocolVersion.VERSION_1])
def test_pull_server_hello_with_psk(self):
    """Parse a ServerHello that selects PSK index 0, then round-trip it."""
    wire = load("tls_server_hello_with_psk.bin")
    reader = Buffer(data=wire)
    hello = pull_server_hello(reader)
    self.assertTrue(reader.eof())

    expected_random = binascii.unhexlify(
        "ccbaaf04fc1bd5143b2cc6b97520cf37d91470dbfc8127131a7bf0f941e3a137"
    )
    expected_session_id = binascii.unhexlify(
        "9483e7e895d0f4cec17086b0849601c0632662cd764e828f2f892f4c4b7771b0"
    )
    expected_key_share = (
        tls.Group.SECP256R1,
        binascii.unhexlify(
            "0485d7cecbebfc548fc657bf51b8e8da842a4056b164a27f7702ca318c16e488"
            "18b6409593b15c6649d6f459387a53128b164178adc840179aad01d36ce95d62"
            "76"
        ),
    )

    self.assertEqual(hello.random, expected_random)
    self.assertEqual(hello.session_id, expected_session_id)
    self.assertEqual(hello.cipher_suite, tls.CipherSuite.AES_256_GCM_SHA384)
    self.assertEqual(hello.compression_method, tls.CompressionMethod.NULL)
    self.assertEqual(hello.key_share, expected_key_share)
    self.assertEqual(hello.pre_shared_key, 0)
    self.assertEqual(hello.supported_version, tls.TLS_VERSION_1_3)

    # serialize: re-encoding must reproduce the captured bytes exactly
    writer = Buffer(1000)
    push_server_hello(writer, hello)
    self.assertEqual(writer.data, wire)
def test_pull_client_hello(self):
    """Parse a ClientHello without SNI or ALPN and check every decoded field."""
    reader = Buffer(data=load("tls_client_hello.bin"))
    hello = pull_client_hello(reader)
    self.assertTrue(reader.eof())

    key_exchange = binascii.unhexlify(
        "047bfea344467535054263b75def60cffa82405a211b68d1eb8d1d944e67aef8"
        "93c7665a5473d032cfaf22a73da28eb4aacae0017ed12557b5791f98a1e84f15"
        "b0"
    )
    expectations = [
        (
            "random",
            binascii.unhexlify(
                "18b2b23bf3e44b5d52ccfe7aecbc5ff14eadc3d349fabf804d71f165ae76e7d5"
            ),
        ),
        (
            "session_id",
            binascii.unhexlify(
                "9aee82a2d186c1cb32a329d9dcfe004a1a438ad0485a53c6bfcf55c132a23235"
            ),
        ),
        (
            "cipher_suites",
            [
                tls.CipherSuite.AES_256_GCM_SHA384,
                tls.CipherSuite.AES_128_GCM_SHA256,
                tls.CipherSuite.CHACHA20_POLY1305_SHA256,
            ],
        ),
        ("compression_methods", [tls.CompressionMethod.NULL]),
        # extensions
        ("alpn_protocols", None),
        ("key_share", [(tls.Group.SECP256R1, key_exchange)]),
        ("psk_key_exchange_modes", [tls.PskKeyExchangeMode.PSK_DHE_KE]),
        ("server_name", None),
        (
            "signature_algorithms",
            [
                tls.SignatureAlgorithm.RSA_PSS_RSAE_SHA256,
                tls.SignatureAlgorithm.ECDSA_SECP256R1_SHA256,
                tls.SignatureAlgorithm.RSA_PKCS1_SHA256,
                tls.SignatureAlgorithm.RSA_PKCS1_SHA1,
            ],
        ),
        ("supported_groups", [tls.Group.SECP256R1]),
        (
            "supported_versions",
            [
                tls.TLS_VERSION_1_3,
                tls.TLS_VERSION_1_3_DRAFT_28,
                tls.TLS_VERSION_1_3_DRAFT_27,
                tls.TLS_VERSION_1_3_DRAFT_26,
            ],
        ),
        (
            "other_extensions",
            [
                (
                    tls.ExtensionType.QUIC_TRANSPORT_PARAMETERS,
                    CLIENT_QUIC_TRANSPORT_PARAMETERS,
                )
            ],
        ),
    ]
    for field_name, expected in expectations:
        self.assertEqual(getattr(hello, field_name), expected)
def parse_max_push_id(data: bytes) -> int:
    """Return the single variable-length integer encoded in ``data``.

    Trailing bytes after that integer trip an ``assert``. The check is skipped
    when Python runs with ``-O``.
    """
    reader = Buffer(data=data)
    push_id_limit = reader.pull_uint_var()
    assert reader.eof()
    return push_id_limit
def _receive_stream_data(
    self, stream_id: int, data: bytes, stream_ended: bool
) -> List[Event]:
    """
    Buffer bytes received on ``stream_id`` and turn complete frames into events.

    For a unidirectional stream whose type has not been recorded yet, the first
    variable-length integer is read as the stream type and stored in
    ``self._stream_types``. Control / QPACK decoder / QPACK encoder streams are
    also remembered by id. Then complete type/length/payload frames are parsed:

    * on streams with ``stream_id % 4 == 0``, DATA frames become
      ``DataReceived`` and HEADERS frames become ``ResponseReceived`` (client)
      or ``RequestReceived`` (server);
    * on the peer control stream, a SETTINGS frame configures the QPACK encoder.

    Any other frame is consumed and ignored. An incomplete trailing item stays
    in ``self._stream_buffers[stream_id]`` for the next call.

    :param stream_id: id of the stream the data arrived on.
    :param data: newly received bytes.
    :param stream_ended: whether the peer has finished the stream.
    :return: the events produced by this call.
    """
    http_events: List[Event] = []

    # append to whatever was left over from previous calls
    if stream_id in self._stream_buffers:
        self._stream_buffers[stream_id] += data
    else:
        self._stream_buffers[stream_id] = data
    # offset just past the last fully processed item in the buffer
    consumed = 0

    buf = Buffer(data=self._stream_buffers[stream_id])
    while not buf.eof():
        # fetch stream type for unidirectional streams
        if (
            stream_is_unidirectional(stream_id)
            and stream_id not in self._stream_types
        ):
            try:
                stream_type = buf.pull_uint_var()
            except BufferReadError:
                break
            consumed = buf.tell()

            # NOTE(review): `assert` is stripped under `python -O`; a peer
            # opening a second stream of these types would then silently
            # replace the recorded id — confirm whether this needs a real error.
            if stream_type == StreamType.CONTROL:
                assert self._peer_control_stream_id is None
                self._peer_control_stream_id = stream_id
            elif stream_type == StreamType.QPACK_DECODER:
                assert self._peer_decoder_stream_id is None
                self._peer_decoder_stream_id = stream_id
            elif stream_type == StreamType.QPACK_ENCODER:
                assert self._peer_encoder_stream_id is None
                self._peer_encoder_stream_id = stream_id

            self._stream_types[stream_id] = stream_type

        # fetch next frame
        # NOTE(review): after the type, every stream's payload (including
        # QPACK encoder/decoder streams) is parsed as type/length/payload
        # frames here — confirm that is intended for the QPACK streams.
        try:
            frame_type = buf.pull_uint_var()
            frame_length = buf.pull_uint_var()
            frame_data = buf.pull_bytes(frame_length)
        except BufferReadError:
            # incomplete frame: leave it buffered until more data arrives
            break
        consumed = buf.tell()

        if (stream_id % 4) == 0:
            # client-initiated bidirectional streams carry requests and responses
            if frame_type == FrameType.DATA:
                http_events.append(
                    DataReceived(
                        data=frame_data,
                        stream_id=stream_id,
                        # only the frame that ends the buffer can end the stream
                        stream_ended=stream_ended and buf.eof(),
                    )
                )
            elif frame_type == FrameType.HEADERS:
                # NOTE(review): the first value returned by feed_header is
                # unused here — confirm it does not need to be sent to the peer.
                control, headers = self._decoder.feed_header(stream_id, frame_data)
                cls = ResponseReceived if self._is_client else RequestReceived
                http_events.append(
                    cls(
                        headers=headers,
                        stream_id=stream_id,
                        stream_ended=stream_ended and buf.eof(),
                    )
                )
        elif stream_id == self._peer_control_stream_id:
            # unidirectional control stream
            if frame_type == FrameType.SETTINGS:
                settings = parse_settings(frame_data)
                # absent settings default to 0
                self._encoder.apply_settings(
                    max_table_capacity=settings.get(
                        Setting.QPACK_MAX_TABLE_CAPACITY, 0
                    ),
                    blocked_streams=settings.get(Setting.QPACK_BLOCKED_STREAMS, 0),
                )

    # remove processed data from buffer
    self._stream_buffers[stream_id] = self._stream_buffers[stream_id][consumed:]

    return http_events
def test_pull_client_hello_with_sni(self):
    """Parse a ClientHello carrying SNI, check each field, then round-trip it."""
    wire = load("tls_client_hello_with_sni.bin")
    reader = Buffer(data=wire)
    hello = pull_client_hello(reader)
    self.assertTrue(reader.eof())

    key_exchange = binascii.unhexlify(
        "04b62d70f907c814cd65d0f73b8b991f06b70c77153f548410a191d2b19764a2"
        "ecc06065a480efa9e1f10c8da6e737d5bfc04be3f773e20a0c997f51b5621280"
        "40"
    )
    expectations = [
        (
            "random",
            binascii.unhexlify(
                "987d8934140b0a42cc5545071f3f9f7f61963d7b6404eb674c8dbe513604346b"
            ),
        ),
        (
            "session_id",
            binascii.unhexlify(
                "26b19bdd30dbf751015a3a16e13bd59002dfe420b799d2a5cd5e11b8fa7bcb66"
            ),
        ),
        (
            "cipher_suites",
            [
                tls.CipherSuite.AES_256_GCM_SHA384,
                tls.CipherSuite.AES_128_GCM_SHA256,
                tls.CipherSuite.CHACHA20_POLY1305_SHA256,
            ],
        ),
        ("compression_methods", [tls.CompressionMethod.NULL]),
        # extensions
        ("alpn_protocols", None),
        ("key_share", [(tls.Group.SECP256R1, key_exchange)]),
        ("psk_key_exchange_modes", [tls.PskKeyExchangeMode.PSK_DHE_KE]),
        ("server_name", "cloudflare-quic.com"),
        (
            "signature_algorithms",
            [
                tls.SignatureAlgorithm.RSA_PSS_RSAE_SHA256,
                tls.SignatureAlgorithm.ECDSA_SECP256R1_SHA256,
                tls.SignatureAlgorithm.RSA_PKCS1_SHA256,
                tls.SignatureAlgorithm.RSA_PKCS1_SHA1,
            ],
        ),
        ("supported_groups", [tls.Group.SECP256R1]),
        (
            "supported_versions",
            [
                tls.TLS_VERSION_1_3,
                tls.TLS_VERSION_1_3_DRAFT_28,
                tls.TLS_VERSION_1_3_DRAFT_27,
                tls.TLS_VERSION_1_3_DRAFT_26,
            ],
        ),
        (
            "other_extensions",
            [
                (
                    tls.ExtensionType.QUIC_TRANSPORT_PARAMETERS,
                    CLIENT_QUIC_TRANSPORT_PARAMETERS,
                )
            ],
        ),
    ]
    for field_name, expected in expectations:
        self.assertEqual(getattr(hello, field_name), expected)

    # serialize: re-encoding must reproduce the captured bytes exactly
    writer = Buffer(1000)
    push_client_hello(writer, hello)
    self.assertEqual(writer.data, wire)
def test_pull_client_hello_with_alpn(self):
    """Parse a ClientHello offering ALPN "h3-19", then re-serialize it."""
    wire = load("tls_client_hello_with_alpn.bin")
    reader = Buffer(data=wire)
    hello = pull_client_hello(reader)
    self.assertTrue(reader.eof())

    expected_random = binascii.unhexlify(
        "ed575c6fbd599c4dfaabd003dca6e860ccdb0e1782c1af02e57bf27cb6479b76"
    )
    expected_suites = [
        tls.CipherSuite.AES_128_GCM_SHA256,
        tls.CipherSuite.AES_256_GCM_SHA384,
        tls.CipherSuite.CHACHA20_POLY1305_SHA256,
        tls.CipherSuite.EMPTY_RENEGOTIATION_INFO_SCSV,
    ]
    self.assertEqual(hello.random, expected_random)
    self.assertEqual(hello.session_id, b"")
    self.assertEqual(hello.cipher_suites, expected_suites)
    self.assertEqual(hello.compression_methods, [tls.CompressionMethod.NULL])

    # extensions
    expected_key_share = [
        (
            tls.Group.SECP256R1,
            binascii.unhexlify(
                "048842315c437bb0ce2929c816fee4e942ec5cb6db6a6b9bf622680188ebb0d4"
                "b652e69033f71686aa01cbc79155866e264c9f33f45aa16b0dfa10a222e3a669"
                "22"
            ),
        )
    ]
    expected_sig_algs = [
        tls.SignatureAlgorithm.ECDSA_SECP256R1_SHA256,
        tls.SignatureAlgorithm.ECDSA_SECP384R1_SHA384,
        tls.SignatureAlgorithm.ECDSA_SECP521R1_SHA512,
        tls.SignatureAlgorithm.ED25519,
        tls.SignatureAlgorithm.ED448,
        tls.SignatureAlgorithm.RSA_PSS_PSS_SHA256,
        tls.SignatureAlgorithm.RSA_PSS_PSS_SHA384,
        tls.SignatureAlgorithm.RSA_PSS_PSS_SHA512,
        tls.SignatureAlgorithm.RSA_PSS_RSAE_SHA256,
        tls.SignatureAlgorithm.RSA_PSS_RSAE_SHA384,
        tls.SignatureAlgorithm.RSA_PSS_RSAE_SHA512,
        tls.SignatureAlgorithm.RSA_PKCS1_SHA256,
        tls.SignatureAlgorithm.RSA_PKCS1_SHA384,
        tls.SignatureAlgorithm.RSA_PKCS1_SHA512,
    ]
    expected_groups = [
        tls.Group.SECP256R1,
        tls.Group.X25519,
        tls.Group.SECP384R1,
        tls.Group.SECP521R1,
    ]
    self.assertEqual(hello.alpn_protocols, ["h3-19"])
    self.assertEqual(hello.early_data, False)
    self.assertEqual(hello.key_share, expected_key_share)
    self.assertEqual(
        hello.psk_key_exchange_modes, [tls.PskKeyExchangeMode.PSK_DHE_KE]
    )
    self.assertEqual(hello.server_name, "cloudflare-quic.com")
    self.assertEqual(hello.signature_algorithms, expected_sig_algs)
    self.assertEqual(hello.supported_groups, expected_groups)
    self.assertEqual(hello.supported_versions, [tls.TLS_VERSION_1_3])

    # serialize: only the total length is compared with the capture, not the
    # bytes. NOTE(review): presumably re-encoding does not reproduce this
    # capture byte-for-byte (e.g. extension order) — confirm.
    writer = Buffer(1000)
    push_client_hello(writer, hello)
    self.assertEqual(len(writer.data), len(wire))
def _receive_stream_data_uni(self, stream_id: int, data: bytes) -> List[HttpEvent]:
    """
    Handle data received on a peer-initiated unidirectional stream.

    The first variable-length integer on the stream is its type. Control,
    QPACK decoder and QPACK encoder streams are remembered by id. On the
    control stream, complete frames are parsed; a SETTINGS frame configures
    the QPACK encoder, and the returned bytes are sent on the local encoder
    stream. On every other stream, all buffered bytes are taken as unframed
    data: the QPACK decoder stream feeds the encoder, the QPACK encoder stream
    feeds the decoder, and anything else is dropped.

    Request streams that the new encoder instructions unblocked are then
    resumed. Their headers are emitted as an event, and the decoder-stream
    bytes produced by resuming are sent to the peer. Any data buffered behind
    those headers is processed too.

    :param stream_id: id of the unidirectional stream the data arrived on.
    :param data: newly received bytes.
    :return: the HTTP events produced by this call.
    """
    http_events: List[HttpEvent] = []

    stream = self._stream[stream_id]
    stream.buffer += data

    buf = Buffer(data=stream.buffer)
    # offset just past the last fully processed item in the buffer
    consumed = 0
    unblocked_streams: Set[int] = set()

    while not buf.eof():
        # fetch stream type for unidirectional streams
        if stream.stream_type is None:
            try:
                stream.stream_type = buf.pull_uint_var()
            except BufferReadError:
                break
            consumed = buf.tell()

            # NOTE: `assert` is stripped under `python -O`
            if stream.stream_type == StreamType.CONTROL:
                assert self._peer_control_stream_id is None
                self._peer_control_stream_id = stream_id
            elif stream.stream_type == StreamType.QPACK_DECODER:
                assert self._peer_decoder_stream_id is None
                self._peer_decoder_stream_id = stream_id
            elif stream.stream_type == StreamType.QPACK_ENCODER:
                assert self._peer_encoder_stream_id is None
                self._peer_encoder_stream_id = stream_id

        if stream_id == self._peer_control_stream_id:
            # control stream: parse one complete type/length/payload frame
            try:
                frame_type = buf.pull_uint_var()
                frame_length = buf.pull_uint_var()
                frame_data = buf.pull_bytes(frame_length)
            except BufferReadError:
                # incomplete frame: keep it buffered until more data arrives
                break
            consumed = buf.tell()

            if frame_type == FrameType.SETTINGS:
                settings = parse_settings(frame_data)
                encoder = self._encoder.apply_settings(
                    max_table_capacity=settings.get(
                        Setting.QPACK_MAX_TABLE_CAPACITY, 0
                    ),
                    blocked_streams=settings.get(Setting.QPACK_BLOCKED_STREAMS, 0),
                )
                self._quic.send_stream_data(self._local_encoder_stream_id, encoder)
        else:
            # other streams carry unframed data: take everything buffered
            unframed = buf.pull_bytes(buf.capacity - buf.tell())
            consumed = buf.tell()

            if stream_id == self._peer_decoder_stream_id:
                self._encoder.feed_decoder(unframed)
            elif stream_id == self._peer_encoder_stream_id:
                unblocked_streams.update(self._decoder.feed_encoder(unframed))

    # remove processed data from buffer
    stream.buffer = stream.buffer[consumed:]

    # resume request streams unblocked by the new encoder instructions
    # (distinct loop names so the `stream_id` / `stream` above are not shadowed)
    for unblocked_id in unblocked_streams:
        unblocked = self._stream[unblocked_id]
        decoder, headers = self._decoder.resume_header(unblocked_id)
        # Send the decoder-stream bytes produced by decoding this header block
        # to the peer. They used to be discarded, unlike the feed_header output
        # on the regular HEADERS path, which is sent on the local decoder stream.
        self._quic.send_stream_data(self._local_decoder_stream_id, decoder)
        unblocked.blocked = False
        cls = ResponseReceived if self._is_client else RequestReceived
        http_events.append(
            cls(
                headers=headers,
                stream_id=unblocked_id,
                stream_ended=unblocked.ended and not unblocked.buffer,
            )
        )
        http_events.extend(
            self._receive_stream_data_bidi(unblocked_id, b"", unblocked.ended)
        )

    return http_events
def _receive_request_or_push_data(
    self, stream: H3Stream, data: bytes, stream_ended: bool
) -> List[H3Event]:
    """
    Handle data received on a request or push stream.

    Incoming bytes are appended to ``stream.buffer`` and parsed as HTTP/3
    frames. Frame type and remaining size are kept on ``stream`` across
    calls. DATA frame payloads may be delivered in pieces as they arrive.
    Any other frame is only handled once its whole payload is buffered.
    Each (piece of a) frame is passed to ``_handle_request_or_push_frame``.

    If that raises ``pylsqpack.StreamBlocked``, the stream is marked blocked
    and nothing more is parsed until it is unblocked.

    :param stream: the request or push stream state.
    :param data: newly received bytes (may be empty).
    :param stream_ended: whether the peer has finished the stream.
    :return: the HTTP events produced by this call.
    """
    http_events: List[H3Event] = []

    stream.buffer += data
    if stream_ended:
        stream.ended = True
    # a blocked stream only accumulates data until it is unblocked
    if stream.blocked:
        return http_events

    # shortcut for DATA frame fragments
    # (the whole buffer is payload of the current DATA frame, which is still
    # not complete, so forward it directly without building a Buffer)
    if (stream.frame_type == FrameType.DATA
            and stream.frame_size is not None
            and len(stream.buffer) < stream.frame_size):
        http_events.append(
            DataReceived(
                data=stream.buffer,
                push_id=stream.push_id,
                stream_id=stream.stream_id,
                stream_ended=False,
            ))
        stream.frame_size -= len(stream.buffer)
        stream.buffer = b""
        return http_events

    # handle lone FIN
    if stream_ended and not stream.buffer:
        http_events.append(
            DataReceived(
                data=b"",
                push_id=stream.push_id,
                stream_id=stream.stream_id,
                stream_ended=True,
            ))
        return http_events

    buf = Buffer(data=stream.buffer)
    # offset just past the last fully processed bytes in the buffer
    consumed = 0

    while not buf.eof():
        # fetch next frame header
        if stream.frame_size is None:
            try:
                stream.frame_type = buf.pull_uint_var()
                stream.frame_size = buf.pull_uint_var()
            except BufferReadError:
                break
            consumed = buf.tell()

            # log frame
            if (self._quic_logger is not None
                    and stream.frame_type == FrameType.DATA):
                self._quic_logger.log_event(
                    category="http",
                    event="frame_parsed",
                    data=qlog_encode_data_frame(
                        byte_length=stream.frame_size,
                        stream_id=stream.stream_id),
                )

        # check how much data is available
        chunk_size = min(stream.frame_size, buf.capacity - consumed)
        # non-DATA frames are only handled once fully buffered
        if stream.frame_type != FrameType.DATA and chunk_size < stream.frame_size:
            break

        # read available data
        frame_data = buf.pull_bytes(chunk_size)
        consumed = buf.tell()

        # detect end of frame
        stream.frame_size -= chunk_size
        if not stream.frame_size:
            stream.frame_size = None

        try:
            http_events.extend(
                self._handle_request_or_push_frame(
                    frame_type=stream.frame_type,
                    frame_data=frame_data,
                    stream=stream,
                    stream_ended=stream.ended and buf.eof(),
                ))
        except pylsqpack.StreamBlocked:
            # `consumed` is already past this frame, so its bytes leave the
            # buffer; blocked_frame_size records how large it was.
            stream.blocked = True
            stream.blocked_frame_size = len(frame_data)
            break

    # remove processed data from buffer
    stream.buffer = stream.buffer[consumed:]

    return http_events
def _receive_stream_data_uni(
    self, stream: H3Stream, data: bytes, stream_ended: bool
) -> List[H3Event]:
    """
    Handle data received on a peer-initiated unidirectional stream.

    The first variable-length integer is the stream type. Only one control,
    QPACK decoder and QPACK encoder stream may be opened by the peer; a second
    one raises ``StreamCreationError``. After that, the handling depends on
    the type:

    * control: complete frames are passed to ``_handle_control_frame``;
    * push: the push id is read, then processing continues via
      ``_receive_request_or_push_data``;
    * QPACK decoder / encoder: all buffered bytes are fed to the QPACK
      encoder / decoder, and errors are re-raised as
      ``QpackDecoderStreamError`` / ``QpackEncoderStreamError``;
    * anything else: the data is discarded.

    Request or push streams that the new encoder instructions unblocked are
    then resumed.

    :param stream: the unidirectional stream state.
    :param data: newly received bytes.
    :param stream_ended: whether the peer has finished the stream.
    :return: the HTTP events produced by this call.
    """
    http_events: List[H3Event] = []

    stream.buffer += data
    if stream_ended:
        stream.ended = True

    buf = Buffer(data=stream.buffer)
    # offset just past the last fully processed bytes in the buffer
    consumed = 0
    unblocked_streams: Set[int] = set()

    # A push stream enters the loop even when the buffer is exhausted; its
    # branch below always returns (or breaks if the push id is incomplete).
    while stream.stream_type == StreamType.PUSH or not buf.eof():
        # fetch stream type for unidirectional streams
        if stream.stream_type is None:
            try:
                stream.stream_type = buf.pull_uint_var()
            except BufferReadError:
                break
            consumed = buf.tell()

            # check unicity
            if stream.stream_type == StreamType.CONTROL:
                if self._peer_control_stream_id is not None:
                    raise StreamCreationError(
                        "Only one control stream is allowed")
                self._peer_control_stream_id = stream.stream_id
            elif stream.stream_type == StreamType.QPACK_DECODER:
                if self._peer_decoder_stream_id is not None:
                    raise StreamCreationError(
                        "Only one QPACK decoder stream is allowed")
                self._peer_decoder_stream_id = stream.stream_id
            elif stream.stream_type == StreamType.QPACK_ENCODER:
                if self._peer_encoder_stream_id is not None:
                    raise StreamCreationError(
                        "Only one QPACK encoder stream is allowed")
                self._peer_encoder_stream_id = stream.stream_id

        if stream.stream_type == StreamType.CONTROL:
            # fetch next frame
            try:
                frame_type = buf.pull_uint_var()
                frame_length = buf.pull_uint_var()
                frame_data = buf.pull_bytes(frame_length)
            except BufferReadError:
                # incomplete frame: keep it buffered until more data arrives
                break
            consumed = buf.tell()

            self._handle_control_frame(frame_type, frame_data)
        elif stream.stream_type == StreamType.PUSH:
            # fetch push id
            if stream.push_id is None:
                try:
                    stream.push_id = buf.pull_uint_var()
                except BufferReadError:
                    break
                consumed = buf.tell()

            # remove processed data from buffer
            stream.buffer = stream.buffer[consumed:]

            # the rest of a push stream is parsed like a request stream
            return self._receive_request_or_push_data(
                stream, b"", stream_ended)
        elif stream.stream_type == StreamType.QPACK_DECODER:
            # feed unframed data to decoder
            data = buf.pull_bytes(buf.capacity - buf.tell())
            consumed = buf.tell()
            try:
                self._encoder.feed_decoder(data)
            except pylsqpack.DecoderStreamError as exc:
                raise QpackDecoderStreamError() from exc
            self._decoder_bytes_received += len(data)
        elif stream.stream_type == StreamType.QPACK_ENCODER:
            # feed unframed data to encoder
            data = buf.pull_bytes(buf.capacity - buf.tell())
            consumed = buf.tell()
            try:
                unblocked_streams.update(self._decoder.feed_encoder(data))
            except pylsqpack.EncoderStreamError as exc:
                raise QpackEncoderStreamError() from exc
            self._encoder_bytes_received += len(data)
        else:
            # unknown stream type, discard data
            buf.seek(buf.capacity)
            consumed = buf.tell()

    # remove processed data from buffer
    stream.buffer = stream.buffer[consumed:]

    # process unblocked streams
    # (note: this rebinds `stream`, which is no longer needed for the
    # unidirectional stream at this point)
    for stream_id in unblocked_streams:
        stream = self._stream[stream_id]

        # resume headers
        # (frame_data=None: the blocked header block's bytes were already
        # consumed when it blocked)
        http_events.extend(
            self._handle_request_or_push_frame(
                frame_type=FrameType.HEADERS,
                frame_data=None,
                stream=stream,
                stream_ended=stream.ended and not stream.buffer,
            ))
        stream.blocked = False
        stream.blocked_frame_size = None

        # resume processing
        if stream.buffer:
            http_events.extend(
                self._receive_request_or_push_data(stream, b"", stream.ended))

    return http_events
def _receive_stream_data_bidi(
    self, stream_id: int, data: bytes, stream_ended: bool
) -> List[HttpEvent]:
    """
    Client-initiated bidirectional streams carry requests and responses.

    Incoming bytes are appended to the stream's buffer and parsed as frames.
    Frame type and remaining size are kept on the stream object across calls.
    DATA payloads are emitted as ``DataReceived`` as soon as any part of them
    arrives. HEADERS frames are decoded only once fully buffered; the
    decoder-stream bytes from ``feed_header`` are sent on the local decoder
    stream. If decoding raises ``StreamBlocked``, the stream is marked blocked
    and parsing stops until it is resumed. Other frame types are consumed
    without producing events.

    :param stream_id: id of the bidirectional stream.
    :param data: newly received bytes (may be empty).
    :param stream_ended: whether the peer has finished the stream.
    :return: the HTTP events produced by this call.
    """
    http_events: List[HttpEvent] = []

    stream = self._stream[stream_id]
    stream.buffer += data
    if stream_ended:
        stream.ended = True
    # a blocked stream only accumulates data until it is unblocked
    if stream.blocked:
        return http_events

    # shortcut DATA frame bits
    # (the whole buffer belongs to the current, still incomplete DATA frame)
    if (stream.frame_size is not None
            and stream.frame_type == FrameType.DATA
            and len(stream.buffer) < stream.frame_size):
        http_events.append(
            DataReceived(data=stream.buffer,
                         stream_id=stream_id,
                         stream_ended=False))
        stream.frame_size -= len(stream.buffer)
        stream.buffer = b""
        return http_events

    # some peers (e.g. f5) end the stream with no data
    if stream_ended and not stream.buffer:
        http_events.append(
            DataReceived(data=b"", stream_id=stream_id, stream_ended=True))
        return http_events

    buf = Buffer(data=stream.buffer)
    # offset just past the last fully processed bytes in the buffer
    consumed = 0

    while not buf.eof():
        # fetch next frame header
        if stream.frame_size is None:
            try:
                stream.frame_type = buf.pull_uint_var()
                stream.frame_size = buf.pull_uint_var()
            except BufferReadError:
                break
            consumed = buf.tell()

        # check how much data is available
        chunk_size = min(stream.frame_size, buf.capacity - consumed)
        # HEADERS must be complete before decoding; other frames may be split
        if (stream.frame_type == FrameType.HEADERS
                and chunk_size < stream.frame_size):
            break

        # read available data
        frame_data = buf.pull_bytes(chunk_size)
        consumed = buf.tell()

        # detect end of frame
        stream.frame_size -= chunk_size
        if not stream.frame_size:
            stream.frame_size = None

        # an empty DATA chunk only produces an event when it can carry the FIN
        if stream.frame_type == FrameType.DATA and (stream_ended or frame_data):
            http_events.append(
                DataReceived(
                    data=frame_data,
                    stream_id=stream_id,
                    stream_ended=stream_ended and buf.eof(),
                ))
        elif stream.frame_type == FrameType.HEADERS:
            try:
                decoder, headers = self._decoder.feed_header(
                    stream_id, frame_data)
            except StreamBlocked:
                # `consumed` is already past this HEADERS frame; it is
                # presumably resumed later from the decoder's own state —
                # confirm against the unblocking code path.
                stream.blocked = True
                break
            self._quic.send_stream_data(self._local_decoder_stream_id, decoder)
            cls = ResponseReceived if self._is_client else RequestReceived
            http_events.append(
                cls(
                    headers=headers,
                    stream_id=stream_id,
                    stream_ended=stream_ended and buf.eof(),
                ))

    # remove processed data from buffer
    stream.buffer = stream.buffer[consumed:]

    return http_events