def test_seek(self):
    """Seeking moves the read cursor; eof() only becomes true at the end."""
    reader = Buffer(data=b"01234567")

    # A fresh buffer starts at offset 0 and is not exhausted.
    self.assertFalse(reader.eof())
    self.assertEqual(reader.tell(), 0)

    # Seek to the middle, then to the very end of the 8 bytes of data.
    for offset, at_end in ((4, False), (8, True)):
        reader.seek(offset)
        check_eof = self.assertTrue if at_end else self.assertFalse
        check_eof(reader.eof())
        self.assertEqual(reader.tell(), offset)
def test_pull_retry(self):
    """Parse a draft-25 Retry packet and verify its integrity tag."""
    reader = Buffer(data=load("retry.bin"))
    parsed = pull_quic_header(reader, host_cid_length=8)

    self.assertTrue(parsed.is_long_header)
    self.assertEqual(parsed.version, QuicProtocolVersion.DRAFT_25)
    self.assertEqual(parsed.packet_type, PACKET_TYPE_RETRY)

    expected_dcid = binascii.unhexlify("e9d146d8d14cb28e")
    self.assertEqual(parsed.destination_cid, expected_dcid)

    expected_scid = binascii.unhexlify("0b0a205a648fcf82d85f128b67bbe08053e6")
    self.assertEqual(parsed.source_cid, expected_scid)

    expected_token = binascii.unhexlify(
        "44397a35d698393c134b08a932737859f446d3aadd00ed81540c8d8de172"
        "906d3e7a111b503f9729b8928e7528f9a86a4581f9ebb4cb3b53c283661e"
        "8530741a99192ee56914c5626998ec0f"
    )
    self.assertEqual(parsed.token, expected_token)

    expected_tag = binascii.unhexlify("e1a3c80c797ea401c08fc9c342a2d90d")
    self.assertEqual(parsed.integrity_tag, expected_tag)

    self.assertEqual(parsed.rest_length, 0)
    self.assertEqual(reader.tell(), 125)

    # check integrity
    # (presumably the first 109 bytes are the packet without its 16-byte tag)
    original_dcid = binascii.unhexlify("fbbd219b7363b64b")
    computed_tag = get_retry_integrity_tag(reader.data_slice(0, 109), original_dcid)
    self.assertEqual(computed_tag, parsed.integrity_tag)
def roundtrip(self, data, value):
    """Assert that *data* decodes to *value* and *value* encodes to *data*."""
    # Decoding direction: the value must match and all input be consumed.
    reader = Buffer(data=data)
    decoded = pull_uint_var(reader)
    self.assertEqual(decoded, value)
    self.assertEqual(reader.tell(), len(data))

    # Encoding direction: writing the value must reproduce the input bytes.
    writer = Buffer(capacity=8)
    push_uint_var(writer, value)
    self.assertEqual(writer.data, data)
def test_seek(self):
    """Valid seeks move the cursor; out-of-range seeks raise and leave it."""
    reader = Buffer(data=b"01234567")

    # A fresh buffer starts at offset 0 and is not exhausted.
    self.assertFalse(reader.eof())
    self.assertEqual(reader.tell(), 0)

    # Seek to the middle, then to the very end of the 8 bytes of data.
    for offset, at_end in ((4, False), (8, True)):
        reader.seek(offset)
        check_eof = self.assertTrue if at_end else self.assertFalse
        check_eof(reader.eof())
        self.assertEqual(reader.tell(), offset)

    # Offsets before the start or past the end are rejected and the
    # cursor stays where it was.
    for bad_offset in (-1, 9):
        self.assertRaises(BufferReadError, reader.seek, bad_offset)
        self.assertEqual(reader.tell(), 8)
def _receive_datagram(self, data: bytes) -> List[H3Event]:
    """
    Handle a datagram.

    The datagram starts with a variable-length integer flow ID; whatever
    follows it is delivered as the payload of a single DatagramReceived
    event.

    :raises ProtocolError: if the flow ID cannot be read from ``data``.
    """
    reader = Buffer(data=data)
    try:
        flow_id = reader.pull_uint_var()
    except BufferReadError:
        raise ProtocolError("Could not parse flow ID")

    # Everything after the flow ID is the datagram payload.
    payload = data[reader.tell():]
    event = DatagramReceived(data=payload, flow_id=flow_id)
    return [event]
def test_pull_initial_server(self):
    """Parse a server Initial packet and check every decoded header field."""
    reader = Buffer(data=load("initial_server.bin"))
    parsed = pull_quic_header(reader, host_cid_length=8)

    self.assertTrue(parsed.is_long_header)
    expectations = (
        ("version", QuicProtocolVersion.DRAFT_22),
        ("packet_type", PACKET_TYPE_INITIAL),
        ("destination_cid", b""),
        ("source_cid", binascii.unhexlify("195c68344e28d479")),
        ("original_destination_cid", b""),
        ("token", b""),
        ("rest_length", 184),
    )
    for attribute, wanted in expectations:
        self.assertEqual(getattr(parsed, attribute), wanted)

    # Parsing the header must have advanced the cursor by 18 bytes.
    self.assertEqual(reader.tell(), 18)
def test_pull_short_header(self):
    """Parse a short-header packet and check every decoded header field."""
    reader = Buffer(data=load("short_header.bin"))
    parsed = pull_quic_header(reader, host_cid_length=8)

    self.assertFalse(parsed.is_long_header)
    expectations = (
        ("version", None),
        ("packet_type", 0x50),
        ("destination_cid", binascii.unhexlify("f45aa7b59c0e1ad6")),
        ("source_cid", b""),
        ("original_destination_cid", b""),
        ("token", b""),
        ("rest_length", 12),
    )
    for attribute, wanted in expectations:
        self.assertEqual(getattr(parsed, attribute), wanted)

    # Parsing the header must have advanced the cursor by 9 bytes.
    self.assertEqual(reader.tell(), 9)
def test_pull_version_negotiation(self):
    """Parse a Version Negotiation packet and check the decoded header."""
    reader = Buffer(data=load("version_negotiation.bin"))
    parsed = pull_quic_header(reader, host_cid_length=8)

    self.assertTrue(parsed.is_long_header)
    expectations = (
        ("version", QuicProtocolVersion.NEGOTIATION),
        ("packet_type", None),
        ("destination_cid", binascii.unhexlify("9aac5a49ba87a849")),
        ("source_cid", binascii.unhexlify("f92f4336fa951ba1")),
        ("original_destination_cid", b""),
        ("token", b""),
        ("rest_length", 8),
    )
    for attribute, wanted in expectations:
        self.assertEqual(getattr(parsed, attribute), wanted)

    # Parsing the header must have advanced the cursor by 23 bytes.
    self.assertEqual(reader.tell(), 23)
def test_pull_initial_client(self):
    """Parse a draft-17 client Initial packet and check the decoded header."""
    reader = Buffer(data=load("initial_client.bin"))
    parsed = pull_quic_header(reader, host_cid_length=8)

    self.assertTrue(parsed.is_long_header)
    expectations = (
        ("version", QuicProtocolVersion.DRAFT_17),
        ("packet_type", PACKET_TYPE_INITIAL),
        ("destination_cid", binascii.unhexlify("90ed1e1c7b04b5d3")),
        ("source_cid", b""),
        ("original_destination_cid", b""),
        ("token", b""),
        ("rest_length", 1263),
    )
    for attribute, wanted in expectations:
        self.assertEqual(getattr(parsed, attribute), wanted)

    # Parsing the header must have advanced the cursor by 17 bytes.
    self.assertEqual(reader.tell(), 17)
def test_pull_initial_client(self):
    """Parse a draft-25 client Initial packet and check the decoded header."""
    reader = Buffer(data=load("initial_client.bin"))
    parsed = pull_quic_header(reader, host_cid_length=8)

    self.assertTrue(parsed.is_long_header)
    expectations = (
        ("version", QuicProtocolVersion.DRAFT_25),
        ("packet_type", PACKET_TYPE_INITIAL),
        ("destination_cid", binascii.unhexlify("858b39368b8e3c6e")),
        ("source_cid", b""),
        ("token", b""),
        ("integrity_tag", b""),
        ("rest_length", 1262),
    )
    for attribute, wanted in expectations:
        self.assertEqual(getattr(parsed, attribute), wanted)

    # Parsing the header must have advanced the cursor by 18 bytes.
    self.assertEqual(reader.tell(), 18)
def test_pull_version_negotiation(self):
    """Parse a Version Negotiation packet and check the decoded header."""
    reader = Buffer(data=load("version_negotiation.bin"))
    parsed = pull_quic_header(reader, host_cid_length=8)

    self.assertTrue(parsed.is_long_header)
    expectations = (
        ("version", QuicProtocolVersion.NEGOTIATION),
        ("packet_type", None),
        ("destination_cid", binascii.unhexlify("dae1889b81a91c26")),
        ("source_cid", binascii.unhexlify("f49243784f9bf3be")),
        ("original_destination_cid", b""),
        ("token", b""),
        ("rest_length", 8),
    )
    for attribute, wanted in expectations:
        self.assertEqual(getattr(parsed, attribute), wanted)

    # Parsing the header must have advanced the cursor by 22 bytes.
    self.assertEqual(reader.tell(), 22)
def test_pull_retry(self):
    """
    Parse a version 1 Retry packet and check that re-encoding the parsed
    fields reproduces the original bytes.
    """
    original_destination_cid = binascii.unhexlify("fbbd219b7363b64b")

    data = load("retry.bin")
    buf = Buffer(data=data)
    header = pull_quic_header(buf, host_cid_length=8)
    self.assertTrue(header.is_long_header)
    self.assertEqual(header.version, QuicProtocolVersion.VERSION_1)
    self.assertEqual(header.packet_type, PACKET_TYPE_RETRY)
    self.assertEqual(header.destination_cid, binascii.unhexlify("e9d146d8d14cb28e"))
    self.assertEqual(
        header.source_cid,
        binascii.unhexlify("0b0a205a648fcf82d85f128b67bbe08053e6"),
    )
    self.assertEqual(
        header.token,
        binascii.unhexlify(
            "44397a35d698393c134b08a932737859f446d3aadd00ed81540c8d8de172"
            "906d3e7a111b503f9729b8928e7528f9a86a4581f9ebb4cb3b53c283661e"
            "8530741a99192ee56914c5626998ec0f"
        ),
    )
    self.assertEqual(
        header.integrity_tag,
        binascii.unhexlify("4620aafd42f1d630588b27575a12da5c"),
    )
    self.assertEqual(header.rest_length, 0)
    self.assertEqual(buf.tell(), 125)

    # check integrity
    # NOTE(review): this check was already disabled with `if False:`; it is
    # kept disabled here because re-enabling it may fail -- confirm whether
    # get_retry_integrity_tag() handles this version and then remove the guard.
    if False:
        self.assertEqual(
            get_retry_integrity_tag(
                buf.data_slice(0, 109),
                original_destination_cid,
                version=header.version,
            ),
            header.integrity_tag,
        )

    # serialize
    # (the previous debug dump of the encoded packet to "bob.bin" has been
    # removed so that running the test no longer litters the working directory)
    encoded = encode_quic_retry(
        version=header.version,
        source_cid=header.source_cid,
        destination_cid=header.destination_cid,
        original_destination_cid=original_destination_cid,
        retry_token=header.token,
    )
    self.assertEqual(encoded, data)
def test_pull_retry_draft_28(self):
    """
    Parse a draft-28 Retry packet, verify its integrity tag and check that
    re-encoding the parsed fields reproduces the original bytes.
    """
    original_destination_cid = binascii.unhexlify("fbbd219b7363b64b")
    packet = load("retry_draft_28.bin")

    reader = Buffer(data=packet)
    parsed = pull_quic_header(reader, host_cid_length=8)

    self.assertTrue(parsed.is_long_header)
    self.assertEqual(parsed.version, QuicProtocolVersion.DRAFT_28)
    self.assertEqual(parsed.packet_type, PACKET_TYPE_RETRY)

    expected_dcid = binascii.unhexlify("e9d146d8d14cb28e")
    self.assertEqual(parsed.destination_cid, expected_dcid)

    expected_scid = binascii.unhexlify("0b0a205a648fcf82d85f128b67bbe08053e6")
    self.assertEqual(parsed.source_cid, expected_scid)

    expected_token = binascii.unhexlify(
        "44397a35d698393c134b08a932737859f446d3aadd00ed81540c8d8de172"
        "906d3e7a111b503f9729b8928e7528f9a86a4581f9ebb4cb3b53c283661e"
        "8530741a99192ee56914c5626998ec0f"
    )
    self.assertEqual(parsed.token, expected_token)

    expected_tag = binascii.unhexlify("f15154a271f10139ef6b129033ac38ae")
    self.assertEqual(parsed.integrity_tag, expected_tag)

    self.assertEqual(parsed.rest_length, 0)
    self.assertEqual(reader.tell(), 125)

    # check integrity
    computed_tag = get_retry_integrity_tag(
        reader.data_slice(0, 109),
        original_destination_cid,
        version=parsed.version,
    )
    self.assertEqual(computed_tag, parsed.integrity_tag)

    # serialize
    reencoded = encode_quic_retry(
        version=parsed.version,
        source_cid=parsed.source_cid,
        destination_cid=parsed.destination_cid,
        original_destination_cid=original_destination_cid,
        retry_token=parsed.token,
    )
    self.assertEqual(reencoded, packet)
def test_pull_retry(self):
    """Parse a draft-22 Retry packet and check every decoded header field."""
    reader = Buffer(data=load("retry.bin"))
    parsed = pull_quic_header(reader, host_cid_length=8)

    self.assertTrue(parsed.is_long_header)
    self.assertEqual(parsed.version, QuicProtocolVersion.DRAFT_22)
    self.assertEqual(parsed.packet_type, PACKET_TYPE_RETRY)

    expected_dcid = binascii.unhexlify("fee746dfde699d61")
    self.assertEqual(parsed.destination_cid, expected_dcid)

    expected_scid = binascii.unhexlify("59aa0942fd2f11e9")
    self.assertEqual(parsed.source_cid, expected_scid)

    expected_odcid = binascii.unhexlify("d61e7448e0d63dff")
    self.assertEqual(parsed.original_destination_cid, expected_odcid)

    expected_token = binascii.unhexlify(
        "5282f57f85a1a5c50de5aac2ff7dba43ff34524737099ec41c4b8e8c76734f935e8efd51177dbbe764"
    )
    self.assertEqual(parsed.token, expected_token)

    self.assertEqual(parsed.rest_length, 0)
    self.assertEqual(reader.tell(), 73)
def test_pull_version_negotiation(self):
    """
    Parse a Version Negotiation packet, check the decoded header, then read
    the list of versions that follows it.
    """
    buf = Buffer(data=load("version_negotiation.bin"))
    header = pull_quic_header(buf, host_cid_length=8)
    self.assertTrue(header.is_long_header)
    self.assertEqual(header.version, QuicProtocolVersion.NEGOTIATION)
    self.assertEqual(header.packet_type, None)
    self.assertEqual(header.destination_cid, binascii.unhexlify("9aac5a49ba87a849"))
    self.assertEqual(header.source_cid, binascii.unhexlify("f92f4336fa951ba1"))
    self.assertEqual(header.token, b"")
    self.assertEqual(header.integrity_tag, b"")
    self.assertEqual(header.rest_length, 8)
    self.assertEqual(buf.tell(), 23)

    # The remainder of the packet is a sequence of 32-bit version numbers.
    versions = []
    while not buf.eof():
        versions.append(buf.pull_uint32())
    # (a stray trailing comma that turned this statement into a discarded
    # one-element tuple has been removed)
    self.assertEqual(versions, [0x45474716, QuicProtocolVersion.VERSION_1])
def test_pull_retry(self):
    """Parse a draft-19 Retry packet and check every decoded header field."""
    reader = Buffer(data=load("retry.bin"))
    parsed = pull_quic_header(reader, host_cid_length=8)

    self.assertTrue(parsed.is_long_header)
    self.assertEqual(parsed.version, QuicProtocolVersion.DRAFT_19)
    self.assertEqual(parsed.packet_type, PACKET_TYPE_RETRY)

    expected_dcid = binascii.unhexlify("c98343fe8f5f0ff4")
    self.assertEqual(parsed.destination_cid, expected_dcid)

    expected_scid = binascii.unhexlify("c17f7c0473e635351b85a17e9f3296d7246c")
    self.assertEqual(parsed.source_cid, expected_scid)

    expected_odcid = binascii.unhexlify("85abb547bf28be97")
    self.assertEqual(parsed.original_destination_cid, expected_odcid)

    expected_token = binascii.unhexlify(
        "01652d68d17c8e9f968d4fb4b70c9e526c4f837dbd85abb547bf28be97"
    )
    self.assertEqual(parsed.token, expected_token)

    self.assertEqual(parsed.rest_length, 0)
    self.assertEqual(reader.tell(), 69)
def test_push_uint16(self):
    """Pushing a 16-bit integer writes two bytes, most significant first."""
    expected = b"\x08\x07"
    writer = Buffer(capacity=2)

    push_uint16(writer, 0x0807)

    self.assertEqual(writer.data, expected)
    self.assertEqual(writer.tell(), 2)
def test_push_uint8(self):
    """Pushing an 8-bit integer writes a single byte."""
    expected = b"\x08"
    writer = Buffer(capacity=1)

    push_uint8(writer, 0x08)

    self.assertEqual(writer.data, expected)
    self.assertEqual(writer.tell(), 1)
def test_push_bytes(self):
    """Pushing raw bytes stores them unchanged and advances the cursor."""
    writer = Buffer(capacity=3)

    push_bytes(writer, b"\x08\x07\x06")

    written = writer.data
    self.assertEqual(written, b"\x08\x07\x06")
    position = writer.tell()
    self.assertEqual(position, 3)
def test_pull_uint64_truncated(self):
    """Pulling 8 bytes from a 7-byte buffer fails and reads nothing."""
    reader = Buffer(capacity=7)

    self.assertRaises(BufferReadError, pull_uint64, reader)

    # The failed read must not have moved the cursor.
    self.assertEqual(reader.tell(), 0)
def _receive_stream_data_uni(self, stream: H3Stream, data: bytes, stream_ended: bool) -> List[H3Event]:
    """
    Handle data received on a unidirectional stream.

    The new data is appended to ``stream.buffer`` and as much of the buffer
    as possible is parsed; bytes that could not be processed yet stay in
    ``stream.buffer`` for the next call.

    :param stream: the stream the data was received on.
    :param data: the newly received bytes.
    :param stream_ended: whether the stream has ended; recorded in
        ``stream.ended``.
    :return: the HTTP events produced, including those of request / push
        streams unblocked by QPACK encoder data.
    :raises StreamCreationError: if the peer opens a second control,
        QPACK decoder or QPACK encoder stream.
    :raises QpackDecoderStreamError: if the QPACK decoder stream data is
        rejected.
    :raises QpackEncoderStreamError: if the QPACK encoder stream data is
        rejected.
    """
    http_events: List[H3Event] = []
    stream.buffer += data
    if stream_ended:
        stream.ended = True

    buf = Buffer(data=stream.buffer)
    # number of bytes of `buf` fully processed so far
    consumed = 0
    # IDs of streams reported as unblocked by the QPACK decoder
    unblocked_streams: Set[int] = set()

    # A push stream is handled even when the buffer is empty; its branch
    # always leaves the loop, either via `break` or via `return`.
    while stream.stream_type == StreamType.PUSH or not buf.eof():
        # fetch stream type for unidirectional streams
        if stream.stream_type is None:
            try:
                stream.stream_type = buf.pull_uint_var()
            except BufferReadError:
                # not enough data yet to read the stream type
                break
            consumed = buf.tell()

            # check unicity
            if stream.stream_type == StreamType.CONTROL:
                if self._peer_control_stream_id is not None:
                    raise StreamCreationError(
                        "Only one control stream is allowed")
                self._peer_control_stream_id = stream.stream_id
            elif stream.stream_type == StreamType.QPACK_DECODER:
                if self._peer_decoder_stream_id is not None:
                    raise StreamCreationError(
                        "Only one QPACK decoder stream is allowed")
                self._peer_decoder_stream_id = stream.stream_id
            elif stream.stream_type == StreamType.QPACK_ENCODER:
                if self._peer_encoder_stream_id is not None:
                    raise StreamCreationError(
                        "Only one QPACK encoder stream is allowed")
                self._peer_encoder_stream_id = stream.stream_id

        if stream.stream_type == StreamType.CONTROL:
            # fetch next frame
            try:
                frame_type = buf.pull_uint_var()
                frame_length = buf.pull_uint_var()
                frame_data = buf.pull_bytes(frame_length)
            except BufferReadError:
                # incomplete frame: `consumed` is not advanced, so the
                # partial frame stays buffered
                break
            consumed = buf.tell()

            self._handle_control_frame(frame_type, frame_data)
        elif stream.stream_type == StreamType.PUSH:
            # fetch push id
            if stream.push_id is None:
                try:
                    stream.push_id = buf.pull_uint_var()
                except BufferReadError:
                    break
                consumed = buf.tell()

            # remove processed data from buffer
            stream.buffer = stream.buffer[consumed:]

            # the rest of a push stream is parsed like a request stream;
            # no new data is passed because it is already in stream.buffer
            return self._receive_request_or_push_data(
                stream, b"", stream_ended)
        elif stream.stream_type == StreamType.QPACK_DECODER:
            # feed unframed data to decoder
            data = buf.pull_bytes(buf.capacity - buf.tell())
            consumed = buf.tell()
            try:
                self._encoder.feed_decoder(data)
            except pylsqpack.DecoderStreamError as exc:
                raise QpackDecoderStreamError() from exc
            self._decoder_bytes_received += len(data)
        elif stream.stream_type == StreamType.QPACK_ENCODER:
            # feed unframed data to encoder
            data = buf.pull_bytes(buf.capacity - buf.tell())
            consumed = buf.tell()
            try:
                unblocked_streams.update(self._decoder.feed_encoder(data))
            except pylsqpack.EncoderStreamError as exc:
                raise QpackEncoderStreamError() from exc
            self._encoder_bytes_received += len(data)
        else:
            # unknown stream type, discard data
            buf.seek(buf.capacity)
            consumed = buf.tell()

    # remove processed data from buffer
    stream.buffer = stream.buffer[consumed:]

    # process unblocked streams
    # (note: `stream` is rebound to each unblocked stream from here on)
    for stream_id in unblocked_streams:
        stream = self._stream[stream_id]

        # resume headers
        # frame_data is None because the blocked frame is being resumed
        # rather than delivered as new bytes
        http_events.extend(
            self._handle_request_or_push_frame(
                frame_type=FrameType.HEADERS,
                frame_data=None,
                stream=stream,
                stream_ended=stream.ended and not stream.buffer,
            ))
        stream.blocked = False
        stream.blocked_frame_size = None

        # resume processing
        if stream.buffer:
            http_events.extend(
                self._receive_request_or_push_data(stream, b"", stream.ended))

    return http_events
def _handle_request_or_push_frame(
    self,
    frame_type: int,
    frame_data: Optional[bytes],
    stream: H3Stream,
    stream_ended: bool,
) -> List[H3Event]:
    """
    Handle a frame received on a request or push stream.

    :param frame_type: the type of the frame to handle.
    :param frame_data: the frame payload. It may be ``None`` for a HEADERS
        frame, in which case ``stream.blocked_frame_size`` is logged as the
        frame's byte length.
    :param stream: the stream the frame was received on.
    :param stream_ended: whether the stream ends with this frame.
    :return: the HTTP events produced by the frame; empty for frame types
        that are neither handled nor rejected here.
    :raises FrameUnexpected: if the frame is not allowed in the current
        state or on this kind of stream.

    Decoding headers may raise ``pylsqpack.StreamBlocked``, which is left
    for the caller to handle.
    """
    http_events: List[H3Event] = []

    if frame_type == FrameType.DATA:
        # check DATA frame is allowed
        if stream.headers_recv_state != HeadersState.AFTER_HEADERS:
            raise FrameUnexpected(
                "DATA frame is not allowed in this state")

        # empty DATA is only reported when it carries the end of stream
        if stream_ended or frame_data:
            http_events.append(
                DataReceived(
                    data=frame_data,
                    push_id=stream.push_id,
                    stream_ended=stream_ended,
                    stream_id=stream.stream_id,
                ))
    elif frame_type == FrameType.HEADERS:
        # check HEADERS frame is allowed
        if stream.headers_recv_state == HeadersState.AFTER_TRAILERS:
            raise FrameUnexpected(
                "HEADERS frame is not allowed in this state")

        # try to decode HEADERS, may raise pylsqpack.StreamBlocked
        headers = self._decode_headers(stream.stream_id, frame_data)

        # log frame
        if self._quic_logger is not None:
            self._quic_logger.log_event(
                category="http",
                event="frame_parsed",
                data=qlog_encode_headers_frame(
                    byte_length=stream.blocked_frame_size
                    if frame_data is None else len(frame_data),
                    headers=headers,
                    stream_id=stream.stream_id,
                ),
            )

        # update state and emit headers
        if stream.headers_recv_state == HeadersState.INITIAL:
            stream.headers_recv_state = HeadersState.AFTER_HEADERS
        else:
            stream.headers_recv_state = HeadersState.AFTER_TRAILERS
        http_events.append(
            HeadersReceived(
                headers=headers,
                push_id=stream.push_id,
                stream_id=stream.stream_id,
                stream_ended=stream_ended,
            ))
    # Dispatch on the `frame_type` argument like every other branch, rather
    # than on `stream.frame_type`.
    elif frame_type == FrameType.PUSH_PROMISE and stream.push_id is None:
        if not self._is_client:
            raise FrameUnexpected("Clients must not send PUSH_PROMISE")

        # the payload is a push ID followed by the encoded headers
        frame_buf = Buffer(data=frame_data)
        push_id = frame_buf.pull_uint_var()
        headers = self._decode_headers(
            stream.stream_id, frame_data[frame_buf.tell():])

        # log frame
        if self._quic_logger is not None:
            self._quic_logger.log_event(
                category="http",
                event="frame_parsed",
                data=qlog_encode_push_promise_frame(
                    byte_length=len(frame_data),
                    headers=headers,
                    push_id=push_id,
                    stream_id=stream.stream_id,
                ),
            )

        # emit event
        http_events.append(
            PushPromiseReceived(
                headers=headers,
                push_id=push_id,
                stream_id=stream.stream_id,
            ))
    elif frame_type in (
        FrameType.PRIORITY,
        FrameType.CANCEL_PUSH,
        FrameType.SETTINGS,
        FrameType.PUSH_PROMISE,
        FrameType.GOAWAY,
        FrameType.MAX_PUSH_ID,
        FrameType.DUPLICATE_PUSH,
    ):
        # PUSH_PROMISE only reaches this branch on a push stream, i.e. when
        # stream.push_id is set
        raise FrameUnexpected(
            "Invalid frame type on request stream"
            if stream.push_id is None
            else "Invalid frame type on push stream")

    return http_events
def test_push_bytes_truncated(self):
    """Pushing 4 bytes into a 3-byte buffer fails and writes nothing."""
    writer = Buffer(capacity=3)
    oversized = b"\x08\x07\x06\x05"

    self.assertRaises(BufferWriteError, push_bytes, writer, oversized)

    # The failed write must not have moved the cursor.
    self.assertEqual(writer.tell(), 0)
def test_push_uint32(self):
    """Pushing a 32-bit integer writes four bytes, most significant first."""
    expected = b"\x08\x07\x06\x05"
    writer = Buffer(capacity=4)

    push_uint32(writer, 0x08070605)

    self.assertEqual(writer.data, expected)
    self.assertEqual(writer.tell(), 4)
def test_push_uint64(self):
    """Pushing a 64-bit integer writes eight bytes, most significant first."""
    expected = b"\x08\x07\x06\x05\x04\x03\x02\x01"
    writer = Buffer(capacity=8)

    push_uint64(writer, 0x0807060504030201)

    self.assertEqual(writer.data, expected)
    self.assertEqual(writer.tell(), 8)
def test_pull_bytes_truncated(self):
    """Pulling 2 bytes from an empty buffer fails and reads nothing."""
    reader = Buffer(capacity=0)

    self.assertRaises(BufferReadError, pull_bytes, reader, 2)

    # The failed read must not have moved the cursor.
    self.assertEqual(reader.tell(), 0)
def _receive_request_or_push_data(self, stream: H3Stream, data: bytes, stream_ended: bool) -> List[H3Event]:
    """
    Handle data received on a request or push stream.

    The new data is appended to ``stream.buffer`` and as many frames as
    possible are parsed from it. DATA frames are delivered incrementally as
    their bytes arrive; any other frame is only handled once its payload is
    complete. Unprocessed bytes stay in ``stream.buffer``.

    :param stream: the stream the data was received on.
    :param data: the newly received bytes (may be empty).
    :param stream_ended: whether the stream has ended; recorded in
        ``stream.ended``.
    :return: the HTTP events produced; empty while the stream is blocked.
    """
    http_events: List[H3Event] = []

    stream.buffer += data
    if stream_ended:
        stream.ended = True
    # while blocked, data is only accumulated in stream.buffer
    if stream.blocked:
        return http_events

    # shortcut for DATA frame fragments
    # (everything buffered belongs to the DATA frame in progress, so it is
    # delivered without going through the frame parser)
    if (stream.frame_type == FrameType.DATA
            and stream.frame_size is not None
            and len(stream.buffer) < stream.frame_size):
        http_events.append(
            DataReceived(
                data=stream.buffer,
                push_id=stream.push_id,
                stream_id=stream.stream_id,
                stream_ended=False,
            ))
        stream.frame_size -= len(stream.buffer)
        stream.buffer = b""
        return http_events

    # handle lone FIN
    if stream_ended and not stream.buffer:
        http_events.append(
            DataReceived(
                data=b"",
                push_id=stream.push_id,
                stream_id=stream.stream_id,
                stream_ended=True,
            ))
        return http_events

    buf = Buffer(data=stream.buffer)
    # number of bytes of `buf` fully processed so far
    consumed = 0

    while not buf.eof():
        # fetch next frame header
        # (stream.frame_size is None when no frame is in progress)
        if stream.frame_size is None:
            try:
                stream.frame_type = buf.pull_uint_var()
                stream.frame_size = buf.pull_uint_var()
            except BufferReadError:
                # incomplete header: `consumed` is not advanced, so the
                # partial header stays buffered
                break
            consumed = buf.tell()

            # log frame
            if (self._quic_logger is not None
                    and stream.frame_type == FrameType.DATA):
                self._quic_logger.log_event(
                    category="http",
                    event="frame_parsed",
                    data=qlog_encode_data_frame(
                        byte_length=stream.frame_size,
                        stream_id=stream.stream_id),
                )

        # check how much data is available
        chunk_size = min(stream.frame_size, buf.capacity - consumed)
        # only DATA frames may be handled piecewise; wait for the rest of
        # any other frame
        if stream.frame_type != FrameType.DATA and chunk_size < stream.frame_size:
            break

        # read available data
        frame_data = buf.pull_bytes(chunk_size)
        consumed = buf.tell()

        # detect end of frame
        stream.frame_size -= chunk_size
        if not stream.frame_size:
            stream.frame_size = None

        try:
            http_events.extend(
                self._handle_request_or_push_frame(
                    frame_type=stream.frame_type,
                    frame_data=frame_data,
                    stream=stream,
                    stream_ended=stream.ended and buf.eof(),
                ))
        except pylsqpack.StreamBlocked:
            # remember the frame size, which is used when the frame is
            # resumed without its data
            stream.blocked = True
            stream.blocked_frame_size = len(frame_data)
            break

    # remove processed data from buffer
    stream.buffer = stream.buffer[consumed:]

    return http_events
def test_pull_uint32(self):
    """Pulling a 32-bit integer reads four bytes, most significant first."""
    reader = Buffer(data=b"\x08\x07\x06\x05\x04\x03\x02\x01")

    value = pull_uint32(reader)

    self.assertEqual(value, 0x08070605)
    self.assertEqual(reader.tell(), 4)
class H3CapsuleDecoder:
    """
    A decoder of H3Capsule. This is a streaming decoder and can handle
    multiple capsules.

    Bytes are fed in with append() in arbitrarily sized pieces; iterating
    over the decoder yields every capsule that is complete so far and keeps
    any partial capsule for later.
    """

    def __init__(self) -> None:
        # Input not consumed yet, or None when there is nothing buffered.
        self._buffer: Optional[Buffer] = None
        # Type and length of the capsule currently being decoded; None until
        # the corresponding field has been read.
        self._type: Optional[int] = None
        self._length: Optional[int] = None
        # Set by final(); once set, truncated input is treated as an error.
        self._final: bool = False

    def append(self, data: bytes) -> None:
        """
        Appends the given bytes to this decoder.

        Must not be called after final().
        """
        assert not self._final

        if not data:
            return
        if self._buffer:
            # Keep only the unread tail of the old buffer and put the new
            # data after it.
            remaining = self._buffer.pull_bytes(
                self._buffer.capacity - self._buffer.tell())
            self._buffer = Buffer(data=(remaining + data))
        else:
            self._buffer = Buffer(data=data)

    def final(self) -> None:
        """
        Pushes the end-of-stream mark to this decoder. After calling this,
        calling append() will be invalid.
        """
        self._final = True

    def __iter__(self) -> Iterator[H3Capsule]:
        """
        Yields decoded capsules.

        Iteration stops quietly when the buffered input ends in the middle
        of a capsule and final() has not been called.

        :raises ValueError: if final() was called and the buffered input is
            shorter than the capsule's declared length.
        :raises BufferReadError: if a capsule header cannot be read and
            either final() was called or at least UINT_VAR_MAX_SIZE unread
            bytes are buffered.
        """
        try:
            while self._buffer is not None:
                # The type and length are remembered across calls, so a
                # header split between two append() calls is read only once.
                if self._type is None:
                    self._type = self._buffer.pull_uint_var()
                if self._length is None:
                    self._length = self._buffer.pull_uint_var()
                if self._buffer.capacity - self._buffer.tell() < self._length:
                    if self._final:
                        raise ValueError('insufficient buffer')
                    # Wait for more input.
                    return
                capsule = H3Capsule(
                    self._type, self._buffer.pull_bytes(self._length))
                self._type = None
                self._length = None
                if self._buffer.tell() == self._buffer.capacity:
                    self._buffer = None
                yield capsule
        except BufferReadError:
            if self._final:
                raise
            if not self._buffer:
                # A generator has no meaningful return value, so use a bare
                # return here like every other exit path.
                return
            size = self._buffer.capacity - self._buffer.tell()
            if size >= UINT_VAR_MAX_SIZE:
                raise
            # Ignore the error because there may not be sufficient input.
            return
def test_pull_uint64(self):
    """Pulling a 64-bit integer reads eight bytes, most significant first."""
    reader = Buffer(data=b"\x08\x07\x06\x05\x04\x03\x02\x01")

    value = pull_uint64(reader)

    self.assertEqual(value, 0x0807060504030201)
    self.assertEqual(reader.tell(), 8)