예제 #1
0
    def test_pull_encrypted_extensions_with_alpn_and_early_data(self):
        """Parse EncryptedExtensions carrying ALPN and early_data, then re-encode."""
        source = Buffer(
            data=load("tls_encrypted_extensions_with_alpn_and_early_data.bin")
        )
        extensions = pull_encrypted_extensions(source)
        self.assertIsNotNone(extensions)
        self.assertTrue(source.eof())

        expected_others = [
            (tls.ExtensionType.SERVER_NAME, b""),
            (
                tls.ExtensionType.QUIC_TRANSPORT_PARAMETERS,
                SERVER_QUIC_TRANSPORT_PARAMETERS_3,
            ),
        ]
        expected = EncryptedExtensions(
            alpn_protocol="hq-20",
            early_data=True,
            other_extensions=expected_others,
        )
        self.assertEqual(extensions, expected)

        # serialize into a 116-byte buffer; eof() afterwards presumably means
        # the encoding filled it exactly
        sink = Buffer(116)
        push_encrypted_extensions(sink, extensions)
        self.assertTrue(sink.eof())
예제 #2
0
    def test_seek(self):
        """Seeking moves the cursor reported by tell() and updates eof()."""
        buf = Buffer(data=b"01234567")
        self.assertFalse(buf.eof())
        self.assertEqual(buf.tell(), 0)

        for position, at_end in ((4, False), (8, True)):
            buf.seek(position)
            check_eof = self.assertTrue if at_end else self.assertFalse
            check_eof(buf.eof())
            self.assertEqual(buf.tell(), position)
예제 #3
0
    def test_pull_certificate(self):
        """Parse a Certificate message holding a single certificate entry."""
        buf = Buffer(data=load("tls_certificate.bin"))
        parsed = pull_certificate(buf)
        self.assertTrue(buf.eof())

        expected_entries = [(CERTIFICATE_DATA, b"")]
        self.assertEqual(parsed.request_context, b"")
        self.assertEqual(parsed.certificates, expected_entries)
예제 #4
0
    def test_pull_new_session_ticket_with_unknown_extension(self):
        """Parse a NewSessionTicket with an unknown extension and round-trip it."""
        raw = load("tls_new_session_ticket_with_unknown_extension.bin")
        reader = Buffer(data=raw)
        ticket = pull_new_session_ticket(reader)
        self.assertIsNotNone(ticket)
        self.assertTrue(reader.eof())

        expected = NewSessionTicket(
            ticket_lifetime=86400,
            ticket_age_add=3303452425,
            ticket_nonce=b"",
            ticket=binascii.unhexlify(
                "dbe6f1a77a78c0426bfa607cd0d02b350247d90618704709596beda7e962cc81"
            ),
            max_early_data_size=4294967295,
            other_extensions=[(12345, b"foo")],
        )
        self.assertEqual(ticket, expected)

        # serialize and compare with the captured bytes
        writer = Buffer(100)
        push_new_session_ticket(writer, ticket)
        self.assertEqual(writer.data, raw)
예제 #5
0
    def test_pull_client_hello_with_psk(self):
        """Parse a ClientHello offering a PSK and round-trip it."""
        raw = load("tls_client_hello_with_psk.bin")
        reader = Buffer(data=raw)
        hello = pull_client_hello(reader)

        identity = binascii.unhexlify(
            "fab3dc7d79f35ea53e9adf21150e601591a750b80cde0cd167fef6e0cdbc032a"
            "c4161fc5c5b66679de49524bd5624c50d71ba3e650780a4bfe402d6a06a00525"
            "0b5dc52085233b69d0dd13924cc5c713a396784ecafc59f5ea73c1585d79621b"
            "8a94e4f2291b17427d5185abf4a994fca74ee7a7f993a950c71003fc7cf8"
        )
        binder = binascii.unhexlify(
            "1788ad43fdff37cfc628f24b6ce7c8c76180705380da17da32811b5bae4e78"
            "d7aaaf65a9b713872f2bb28818ca1a6b01"
        )
        expected_psk = tls.OfferedPsks(
            identities=[(identity, 2067156378)],
            binders=[binder],
        )

        self.assertEqual(hello.early_data, True)
        self.assertEqual(hello.pre_shared_key, expected_psk)
        self.assertTrue(reader.eof())

        # serialize and compare with the captured bytes
        writer = Buffer(1000)
        push_client_hello(writer, hello)
        self.assertEqual(writer.data, raw)
예제 #6
0
    def test_pull_server_hello(self):
        """Parse a ServerHello and check every decoded field."""
        buf = Buffer(data=load("tls_server_hello.bin"))
        hello = pull_server_hello(buf)
        self.assertTrue(buf.eof())

        key_exchange = binascii.unhexlify(
            "048b27d0282242d84b7fcc02a9c4f13eca0329e3c7029aa34a33794e6e7ba189"
            "5cca1c503bf0378ac6937c354912116ff3251026bca1958d7f387316c83ae6cf"
            "b2"
        )
        expected_fields = (
            (
                "random",
                binascii.unhexlify(
                    "ada85271d19680c615ea7336519e3fdf6f1e26f3b1075ee1de96ffa8884e8280"
                ),
            ),
            (
                "session_id",
                binascii.unhexlify(
                    "9aee82a2d186c1cb32a329d9dcfe004a1a438ad0485a53c6bfcf55c132a23235"
                ),
            ),
            ("cipher_suite", tls.CipherSuite.AES_256_GCM_SHA384),
            ("compression_method", tls.CompressionMethod.NULL),
            ("key_share", (tls.Group.SECP256R1, key_exchange)),
            ("pre_shared_key", None),
            ("supported_version", tls.TLS_VERSION_1_3),
        )
        for field_name, expected in expected_fields:
            self.assertEqual(getattr(hello, field_name), expected)
예제 #7
0
    def test_pull_server_hello_with_unknown_extension(self):
        """Parse a ServerHello with an unknown extension and round-trip it."""
        raw = load("tls_server_hello_with_unknown_extension.bin")
        reader = Buffer(data=raw)
        hello = pull_server_hello(reader)
        self.assertTrue(reader.eof())

        key_exchange = binascii.unhexlify(
            "048b27d0282242d84b7fcc02a9c4f13eca0329e3c7029aa34a33794e6e7ba189"
            "5cca1c503bf0378ac6937c354912116ff3251026bca1958d7f387316c83ae6cf"
            "b2"
        )
        expected = ServerHello(
            random=binascii.unhexlify(
                "ada85271d19680c615ea7336519e3fdf6f1e26f3b1075ee1de96ffa8884e8280"
            ),
            session_id=binascii.unhexlify(
                "9aee82a2d186c1cb32a329d9dcfe004a1a438ad0485a53c6bfcf55c132a23235"
            ),
            cipher_suite=tls.CipherSuite.AES_256_GCM_SHA384,
            compression_method=tls.CompressionMethod.NULL,
            key_share=(tls.Group.SECP256R1, key_exchange),
            supported_version=tls.TLS_VERSION_1_3,
            other_extensions=[(12345, b"foo")],
        )
        self.assertEqual(hello, expected)

        # serialize and compare with the captured bytes
        writer = Buffer(1000)
        push_server_hello(writer, hello)
        self.assertEqual(writer.data, raw)
예제 #8
0
    def test_pull_certificate_verify(self):
        """Parse a CertificateVerify message and check algorithm and signature."""
        buf = Buffer(data=load("tls_certificate_verify.bin"))
        verify = pull_certificate_verify(buf)
        self.assertTrue(buf.eof())

        for actual, expected in (
            (verify.algorithm, tls.SignatureAlgorithm.RSA_PSS_RSAE_SHA256),
            (verify.signature, CERTIFICATE_VERIFY_SIGNATURE),
        ):
            self.assertEqual(actual, expected)
예제 #9
0
def parse_settings(data: bytes) -> Dict[int, int]:
    """Decode (identifier, value) pairs read with ``pull_uint_var`` into a dict.

    Pairs are read until the buffer is exhausted. If an identifier repeats,
    its last value wins; keys keep the order of their first appearance.
    """
    reader = Buffer(data=data)
    result: Dict[int, int] = {}
    while not reader.eof():
        identifier = reader.pull_uint_var()
        result[identifier] = reader.pull_uint_var()
    return result
예제 #10
0
    def test_seek(self):
        """In-range seeks move the cursor; out-of-range seeks raise and leave it."""
        buf = Buffer(data=b"01234567")
        self.assertFalse(buf.eof())
        self.assertEqual(buf.tell(), 0)

        for position, at_end in ((4, False), (8, True)):
            buf.seek(position)
            check_eof = self.assertTrue if at_end else self.assertFalse
            check_eof(buf.eof())
            self.assertEqual(buf.tell(), position)

        # invalid positions must be rejected without moving the cursor
        for bad_position in (-1, 9):
            with self.assertRaises(BufferReadError):
                buf.seek(bad_position)
            self.assertEqual(buf.tell(), 8)
예제 #11
0
    def test_pull_finished(self):
        """Parse a Finished message and check its verify_data."""
        expected_verify_data = binascii.unhexlify(
            "f157923234ff9a4921aadb2e0ec7b1a30fce73fb9ec0c4276f9af268f408ec68"
        )
        reader = Buffer(data=load("tls_finished.bin"))
        finished = pull_finished(reader)
        self.assertTrue(reader.eof())
        self.assertEqual(finished.verify_data, expected_verify_data)
예제 #12
0
def parse_settings(data: bytes) -> Dict[int, int]:
    """Decode a settings payload into a mapping of identifier -> value.

    The payload is a sequence of (identifier, value) pairs, each read with
    ``pull_uint_var`` until the buffer is exhausted.

    :raises SettingsError: if an identifier is in ``RESERVED_SETTINGS`` or
        appears more than once.
    """
    buf = Buffer(data=data)
    settings: Dict[int, int] = {}
    while not buf.eof():
        setting = buf.pull_uint_var()
        value = buf.pull_uint_var()
        if setting in RESERVED_SETTINGS:
            raise SettingsError("Setting identifier 0x%x is reserved" %
                                setting)
        if setting in settings:
            raise SettingsError("Setting identifier 0x%x is included twice" %
                                setting)
        settings[setting] = value
    # ``settings`` is already a fresh dict built here; copying it is redundant.
    return settings
예제 #13
0
파일: test_tls.py 프로젝트: sysuljx/aioquic
    def test_encrypted_extensions(self):
        """Parse EncryptedExtensions with transport parameters and round-trip it."""
        raw = load("tls_encrypted_extensions.bin")
        reader = Buffer(data=raw)
        extensions = pull_encrypted_extensions(reader)
        self.assertIsNotNone(extensions)
        self.assertTrue(reader.eof())

        transport_parameters = (
            tls.ExtensionType.QUIC_TRANSPORT_PARAMETERS,
            SERVER_QUIC_TRANSPORT_PARAMETERS,
        )
        expected = EncryptedExtensions(other_extensions=[transport_parameters])
        self.assertEqual(extensions, expected)

        # serialize and compare with the captured bytes
        writer = Buffer(capacity=100)
        push_encrypted_extensions(writer, extensions)
        self.assertEqual(writer.data, raw)
예제 #14
0
    def test_pull_version_negotiation(self):
        """Parse a Version Negotiation header, then the versions that follow it."""
        buf = Buffer(data=load("version_negotiation.bin"))
        header = pull_quic_header(buf, host_cid_length=8)
        self.assertTrue(header.is_long_header)
        self.assertEqual(header.version, QuicProtocolVersion.NEGOTIATION)
        self.assertEqual(header.packet_type, None)
        self.assertEqual(header.destination_cid,
                         binascii.unhexlify("9aac5a49ba87a849"))
        self.assertEqual(header.source_cid,
                         binascii.unhexlify("f92f4336fa951ba1"))
        self.assertEqual(header.token, b"")
        self.assertEqual(header.integrity_tag, b"")
        self.assertEqual(header.rest_length, 8)
        self.assertEqual(buf.tell(), 23)

        # the remainder of the packet is a list of 32-bit version numbers
        versions = []
        while not buf.eof():
            versions.append(buf.pull_uint32())
        # no trailing comma: it would turn this statement into a dead tuple
        self.assertEqual(versions,
                         [0x45474716, QuicProtocolVersion.VERSION_1])
예제 #15
0
    def test_pull_server_hello_with_psk(self):
        """Parse a ServerHello that selects a PSK, check its fields, round-trip it."""
        raw = load("tls_server_hello_with_psk.bin")
        reader = Buffer(data=raw)
        hello = pull_server_hello(reader)
        self.assertTrue(reader.eof())

        key_exchange = binascii.unhexlify(
            "0485d7cecbebfc548fc657bf51b8e8da842a4056b164a27f7702ca318c16e488"
            "18b6409593b15c6649d6f459387a53128b164178adc840179aad01d36ce95d62"
            "76"
        )
        expected_fields = (
            (
                "random",
                binascii.unhexlify(
                    "ccbaaf04fc1bd5143b2cc6b97520cf37d91470dbfc8127131a7bf0f941e3a137"
                ),
            ),
            (
                "session_id",
                binascii.unhexlify(
                    "9483e7e895d0f4cec17086b0849601c0632662cd764e828f2f892f4c4b7771b0"
                ),
            ),
            ("cipher_suite", tls.CipherSuite.AES_256_GCM_SHA384),
            ("compression_method", tls.CompressionMethod.NULL),
            ("key_share", (tls.Group.SECP256R1, key_exchange)),
            ("pre_shared_key", 0),
            ("supported_version", tls.TLS_VERSION_1_3),
        )
        for field_name, expected in expected_fields:
            self.assertEqual(getattr(hello, field_name), expected)

        # serialize and compare with the captured bytes
        writer = Buffer(1000)
        push_server_hello(writer, hello)
        self.assertEqual(writer.data, raw)
예제 #16
0
    def test_pull_client_hello(self):
        """Parse a ClientHello and check its fixed fields and extensions."""
        buf = Buffer(data=load("tls_client_hello.bin"))
        hello = pull_client_hello(buf)
        self.assertTrue(buf.eof())

        key_exchange = binascii.unhexlify(
            "047bfea344467535054263b75def60cffa82405a211b68d1eb8d1d944e67aef8"
            "93c7665a5473d032cfaf22a73da28eb4aacae0017ed12557b5791f98a1e84f15"
            "b0"
        )
        expected_fields = (
            (
                "random",
                binascii.unhexlify(
                    "18b2b23bf3e44b5d52ccfe7aecbc5ff14eadc3d349fabf804d71f165ae76e7d5"
                ),
            ),
            (
                "session_id",
                binascii.unhexlify(
                    "9aee82a2d186c1cb32a329d9dcfe004a1a438ad0485a53c6bfcf55c132a23235"
                ),
            ),
            (
                "cipher_suites",
                [
                    tls.CipherSuite.AES_256_GCM_SHA384,
                    tls.CipherSuite.AES_128_GCM_SHA256,
                    tls.CipherSuite.CHACHA20_POLY1305_SHA256,
                ],
            ),
            ("compression_methods", [tls.CompressionMethod.NULL]),
            # extensions
            ("alpn_protocols", None),
            ("key_share", [(tls.Group.SECP256R1, key_exchange)]),
            ("psk_key_exchange_modes", [tls.PskKeyExchangeMode.PSK_DHE_KE]),
            ("server_name", None),
            (
                "signature_algorithms",
                [
                    tls.SignatureAlgorithm.RSA_PSS_RSAE_SHA256,
                    tls.SignatureAlgorithm.ECDSA_SECP256R1_SHA256,
                    tls.SignatureAlgorithm.RSA_PKCS1_SHA256,
                    tls.SignatureAlgorithm.RSA_PKCS1_SHA1,
                ],
            ),
            ("supported_groups", [tls.Group.SECP256R1]),
            (
                "supported_versions",
                [
                    tls.TLS_VERSION_1_3,
                    tls.TLS_VERSION_1_3_DRAFT_28,
                    tls.TLS_VERSION_1_3_DRAFT_27,
                    tls.TLS_VERSION_1_3_DRAFT_26,
                ],
            ),
            (
                "other_extensions",
                [
                    (
                        tls.ExtensionType.QUIC_TRANSPORT_PARAMETERS,
                        CLIENT_QUIC_TRANSPORT_PARAMETERS,
                    )
                ],
            ),
        )
        for field_name, expected in expected_fields:
            self.assertEqual(getattr(hello, field_name), expected)
예제 #17
0
def parse_max_push_id(data: bytes) -> int:
    """Decode a payload that must hold exactly one ``pull_uint_var`` integer.

    Presumably this is the payload of an HTTP/3 MAX_PUSH_ID frame.

    :param data: the raw payload bytes.
    :return: the decoded max push ID.
    :raises ValueError: if bytes remain after the integer.
    """
    buf = Buffer(data=data)
    max_push_id = buf.pull_uint_var()
    # Explicit check instead of ``assert``: asserts are stripped under
    # ``python -O``, and this payload is likely peer-supplied.
    if not buf.eof():
        raise ValueError("MAX_PUSH_ID payload has trailing data")
    return max_push_id
예제 #18
0
    def _receive_stream_data(
        self, stream_id: int, data: bytes, stream_ended: bool
    ) -> List[Event]:
        """Buffer incoming stream data and turn complete frames into events.

        ``data`` is appended to ``self._stream_buffers[stream_id]``. On a
        unidirectional stream whose type is not yet recorded, a leading varint
        is read as the stream type. Then complete frames (type, length,
        payload) are parsed. On streams where ``stream_id % 4 == 0``, DATA and
        HEADERS frames become events. On the peer control stream, SETTINGS
        frames are applied to the QPACK encoder. Fully processed bytes are
        dropped from the buffer; an incomplete tail is kept for the next call.

        :param stream_id: the stream the data arrived on.
        :param data: newly received bytes.
        :param stream_ended: whether the peer ended the stream with this data.
        :return: events produced from the complete frames in the buffer. If no
            complete frame is buffered, nothing is returned, even when
            ``stream_ended`` is True.
        """
        http_events: List[Event] = []

        if stream_id in self._stream_buffers:
            self._stream_buffers[stream_id] += data
        else:
            self._stream_buffers[stream_id] = data
        # offset into the buffer up to which bytes have been fully processed
        consumed = 0

        buf = Buffer(data=self._stream_buffers[stream_id])
        while not buf.eof():
            # fetch stream type for unidirectional streams
            if (
                stream_is_unidirectional(stream_id)
                and stream_id not in self._stream_types
            ):
                try:
                    stream_type = buf.pull_uint_var()
                except BufferReadError:
                    # stream type not fully received yet: wait for more data
                    break
                consumed = buf.tell()

                # NOTE(review): these asserts are the only guard against the
                # peer opening a second stream of the same type, and asserts
                # are stripped under ``python -O``.
                if stream_type == StreamType.CONTROL:
                    assert self._peer_control_stream_id is None
                    self._peer_control_stream_id = stream_id
                elif stream_type == StreamType.QPACK_DECODER:
                    assert self._peer_decoder_stream_id is None
                    self._peer_decoder_stream_id = stream_id
                elif stream_type == StreamType.QPACK_ENCODER:
                    assert self._peer_encoder_stream_id is None
                    self._peer_encoder_stream_id = stream_id
                # unknown stream types are recorded too, so they are not re-read
                self._stream_types[stream_id] = stream_type

            # fetch next frame
            try:
                frame_type = buf.pull_uint_var()
                frame_length = buf.pull_uint_var()
                frame_data = buf.pull_bytes(frame_length)
            except BufferReadError:
                # frame not fully received yet: keep it buffered
                break
            consumed = buf.tell()

            if (stream_id % 4) == 0:
                # client-initiated bidirectional streams carry requests and responses
                # (the end-of-stream flag is only reported for a frame that
                # reaches the end of the buffered data)
                if frame_type == FrameType.DATA:
                    http_events.append(
                        DataReceived(
                            data=frame_data,
                            stream_id=stream_id,
                            stream_ended=stream_ended and buf.eof(),
                        )
                    )
                elif frame_type == FrameType.HEADERS:
                    # NOTE(review): ``control`` is never used; if it holds
                    # output meant for the peer, it is silently dropped — confirm.
                    control, headers = self._decoder.feed_header(stream_id, frame_data)
                    cls = ResponseReceived if self._is_client else RequestReceived
                    http_events.append(
                        cls(
                            headers=headers,
                            stream_id=stream_id,
                            stream_ended=stream_ended and buf.eof(),
                        )
                    )
            elif stream_id == self._peer_control_stream_id:
                # unidirectional control stream
                if frame_type == FrameType.SETTINGS:
                    settings = parse_settings(frame_data)
                    # NOTE(review): the return value of apply_settings is
                    # discarded here — confirm it carries nothing to send.
                    self._encoder.apply_settings(
                        max_table_capacity=settings.get(
                            Setting.QPACK_MAX_TABLE_CAPACITY, 0
                        ),
                        blocked_streams=settings.get(Setting.QPACK_BLOCKED_STREAMS, 0),
                    )
            # NOTE(review): frames on any other stream (e.g. the peer's QPACK
            # encoder/decoder streams) are parsed as frames and then dropped.

        # remove processed data from buffer
        self._stream_buffers[stream_id] = self._stream_buffers[stream_id][consumed:]

        return http_events
예제 #19
0
    def test_pull_client_hello_with_sni(self):
        """Parse a ClientHello carrying SNI, check its fields, and round-trip it."""
        raw = load("tls_client_hello_with_sni.bin")
        reader = Buffer(data=raw)
        hello = pull_client_hello(reader)
        self.assertTrue(reader.eof())

        key_exchange = binascii.unhexlify(
            "04b62d70f907c814cd65d0f73b8b991f06b70c77153f548410a191d2b19764a2"
            "ecc06065a480efa9e1f10c8da6e737d5bfc04be3f773e20a0c997f51b5621280"
            "40"
        )
        expected_fields = (
            (
                "random",
                binascii.unhexlify(
                    "987d8934140b0a42cc5545071f3f9f7f61963d7b6404eb674c8dbe513604346b"
                ),
            ),
            (
                "session_id",
                binascii.unhexlify(
                    "26b19bdd30dbf751015a3a16e13bd59002dfe420b799d2a5cd5e11b8fa7bcb66"
                ),
            ),
            (
                "cipher_suites",
                [
                    tls.CipherSuite.AES_256_GCM_SHA384,
                    tls.CipherSuite.AES_128_GCM_SHA256,
                    tls.CipherSuite.CHACHA20_POLY1305_SHA256,
                ],
            ),
            ("compression_methods", [tls.CompressionMethod.NULL]),
            # extensions
            ("alpn_protocols", None),
            ("key_share", [(tls.Group.SECP256R1, key_exchange)]),
            ("psk_key_exchange_modes", [tls.PskKeyExchangeMode.PSK_DHE_KE]),
            ("server_name", "cloudflare-quic.com"),
            (
                "signature_algorithms",
                [
                    tls.SignatureAlgorithm.RSA_PSS_RSAE_SHA256,
                    tls.SignatureAlgorithm.ECDSA_SECP256R1_SHA256,
                    tls.SignatureAlgorithm.RSA_PKCS1_SHA256,
                    tls.SignatureAlgorithm.RSA_PKCS1_SHA1,
                ],
            ),
            ("supported_groups", [tls.Group.SECP256R1]),
            (
                "supported_versions",
                [
                    tls.TLS_VERSION_1_3,
                    tls.TLS_VERSION_1_3_DRAFT_28,
                    tls.TLS_VERSION_1_3_DRAFT_27,
                    tls.TLS_VERSION_1_3_DRAFT_26,
                ],
            ),
            (
                "other_extensions",
                [
                    (
                        tls.ExtensionType.QUIC_TRANSPORT_PARAMETERS,
                        CLIENT_QUIC_TRANSPORT_PARAMETERS,
                    )
                ],
            ),
        )
        for field_name, expected in expected_fields:
            self.assertEqual(getattr(hello, field_name), expected)

        # serialize and compare with the captured bytes
        writer = Buffer(1000)
        push_client_hello(writer, hello)
        self.assertEqual(writer.data, raw)
예제 #20
0
    def test_pull_client_hello_with_alpn(self):
        """Parse a ClientHello offering ALPN; re-encoding must keep its length."""
        raw = load("tls_client_hello_with_alpn.bin")
        reader = Buffer(data=raw)
        hello = pull_client_hello(reader)
        self.assertTrue(reader.eof())

        key_exchange = binascii.unhexlify(
            "048842315c437bb0ce2929c816fee4e942ec5cb6db6a6b9bf622680188ebb0d4"
            "b652e69033f71686aa01cbc79155866e264c9f33f45aa16b0dfa10a222e3a669"
            "22"
        )
        expected_fields = (
            (
                "random",
                binascii.unhexlify(
                    "ed575c6fbd599c4dfaabd003dca6e860ccdb0e1782c1af02e57bf27cb6479b76"
                ),
            ),
            ("session_id", b""),
            (
                "cipher_suites",
                [
                    tls.CipherSuite.AES_128_GCM_SHA256,
                    tls.CipherSuite.AES_256_GCM_SHA384,
                    tls.CipherSuite.CHACHA20_POLY1305_SHA256,
                    tls.CipherSuite.EMPTY_RENEGOTIATION_INFO_SCSV,
                ],
            ),
            ("compression_methods", [tls.CompressionMethod.NULL]),
            # extensions
            ("alpn_protocols", ["h3-19"]),
            ("early_data", False),
            ("key_share", [(tls.Group.SECP256R1, key_exchange)]),
            ("psk_key_exchange_modes", [tls.PskKeyExchangeMode.PSK_DHE_KE]),
            ("server_name", "cloudflare-quic.com"),
            (
                "signature_algorithms",
                [
                    tls.SignatureAlgorithm.ECDSA_SECP256R1_SHA256,
                    tls.SignatureAlgorithm.ECDSA_SECP384R1_SHA384,
                    tls.SignatureAlgorithm.ECDSA_SECP521R1_SHA512,
                    tls.SignatureAlgorithm.ED25519,
                    tls.SignatureAlgorithm.ED448,
                    tls.SignatureAlgorithm.RSA_PSS_PSS_SHA256,
                    tls.SignatureAlgorithm.RSA_PSS_PSS_SHA384,
                    tls.SignatureAlgorithm.RSA_PSS_PSS_SHA512,
                    tls.SignatureAlgorithm.RSA_PSS_RSAE_SHA256,
                    tls.SignatureAlgorithm.RSA_PSS_RSAE_SHA384,
                    tls.SignatureAlgorithm.RSA_PSS_RSAE_SHA512,
                    tls.SignatureAlgorithm.RSA_PKCS1_SHA256,
                    tls.SignatureAlgorithm.RSA_PKCS1_SHA384,
                    tls.SignatureAlgorithm.RSA_PKCS1_SHA512,
                ],
            ),
            (
                "supported_groups",
                [
                    tls.Group.SECP256R1,
                    tls.Group.X25519,
                    tls.Group.SECP384R1,
                    tls.Group.SECP521R1,
                ],
            ),
            ("supported_versions", [tls.TLS_VERSION_1_3]),
        )
        for field_name, expected in expected_fields:
            self.assertEqual(getattr(hello, field_name), expected)

        # serialize; only the encoded length is compared with the capture
        writer = Buffer(1000)
        push_client_hello(writer, hello)
        self.assertEqual(len(writer.data), len(raw))
예제 #21
0
    def _receive_stream_data_uni(self, stream_id: int,
                                 data: bytes) -> List[HttpEvent]:
        """Process data received on a unidirectional stream.

        The bytes are appended to the stream's buffer. If the stream type is
        not yet known, a leading varint is read as the type, and the peer
        control / QPACK decoder / QPACK encoder stream ids are recorded.
        Control-stream frames are parsed; SETTINGS reconfigures the QPACK
        encoder, and the encoder output is sent on the local encoder stream.
        On other streams, all buffered bytes are consumed as unframed data.
        They are fed to the QPACK encoder (from the peer decoder stream) or
        the QPACK decoder (from the peer encoder stream); otherwise they are
        dropped. Streams that the decoder reports as unblocked then have their
        headers resumed and turned into events.

        :param stream_id: the unidirectional stream the data arrived on.
        :param data: newly received bytes.
        :return: events produced, including those for unblocked streams.
        """
        http_events: List[HttpEvent] = []

        stream = self._stream[stream_id]
        stream.buffer += data

        buf = Buffer(data=stream.buffer)
        # offset into stream.buffer up to which bytes have been processed
        consumed = 0
        # stream ids returned by the decoder's feed_encoder(); their headers
        # are resumed after the loop
        unblocked_streams: Set[int] = set()

        while not buf.eof():
            # fetch stream type for unidirectional streams
            if stream.stream_type is None:
                try:
                    stream.stream_type = buf.pull_uint_var()
                except BufferReadError:
                    break
                consumed = buf.tell()

                # NOTE(review): uniqueness of each stream type is enforced only
                # by asserts, which are stripped under ``python -O``.
                if stream.stream_type == StreamType.CONTROL:
                    assert self._peer_control_stream_id is None
                    self._peer_control_stream_id = stream_id
                elif stream.stream_type == StreamType.QPACK_DECODER:
                    assert self._peer_decoder_stream_id is None
                    self._peer_decoder_stream_id = stream_id
                elif stream.stream_type == StreamType.QPACK_ENCODER:
                    assert self._peer_encoder_stream_id is None
                    self._peer_encoder_stream_id = stream_id

            if stream_id == self._peer_control_stream_id:
                # fetch next frame
                try:
                    frame_type = buf.pull_uint_var()
                    frame_length = buf.pull_uint_var()
                    frame_data = buf.pull_bytes(frame_length)
                except BufferReadError:
                    # frame not fully received yet: keep it buffered
                    break
                consumed = buf.tell()

                # unidirectional control stream
                if frame_type == FrameType.SETTINGS:
                    settings = parse_settings(frame_data)
                    # apply_settings returns bytes that are sent on the
                    # local encoder stream
                    encoder = self._encoder.apply_settings(
                        max_table_capacity=settings.get(
                            Setting.QPACK_MAX_TABLE_CAPACITY, 0),
                        blocked_streams=settings.get(
                            Setting.QPACK_BLOCKED_STREAMS, 0),
                    )
                    self._quic.send_stream_data(self._local_encoder_stream_id,
                                                encoder)
            else:
                # fetch unframed data
                # (takes everything left; presumably buf.capacity equals
                # len(stream.buffer) here. Note this rebinds ``data``.)
                data = buf.pull_bytes(buf.capacity - buf.tell())
                consumed = buf.tell()

                if stream_id == self._peer_decoder_stream_id:
                    self._encoder.feed_decoder(data)

                elif stream_id == self._peer_encoder_stream_id:
                    unblocked_streams.update(self._decoder.feed_encoder(data))

        # remove processed data from buffer
        stream.buffer = stream.buffer[consumed:]

        # process unblocked streams
        # (``stream_id`` and ``stream`` are rebound below; the unidirectional
        # stream's buffer has already been trimmed above)
        for stream_id in unblocked_streams:
            stream = self._stream[stream_id]
            # NOTE(review): the ``decoder`` output of resume_header is
            # discarded, while the SETTINGS path above sends encoder output —
            # confirm whether it should be sent to the peer.
            decoder, headers = self._decoder.resume_header(stream_id)
            stream.blocked = False
            cls = ResponseReceived if self._is_client else RequestReceived
            http_events.append(
                cls(
                    headers=headers,
                    stream_id=stream_id,
                    # end is only reported if nothing more is buffered
                    stream_ended=stream.ended and not stream.buffer,
                ))
            # presumably processes data buffered on that stream while its
            # headers were blocked
            http_events.extend(
                self._receive_stream_data_bidi(stream_id, b"", stream.ended))

        return http_events
예제 #22
0
    def _receive_request_or_push_data(self, stream: H3Stream, data: bytes,
                                      stream_ended: bool) -> List[H3Event]:
        """
        Handle data received on a request or push stream.

        ``data`` is appended to ``stream.buffer``. Unless the stream is marked
        blocked (set when QPACK raises ``pylsqpack.StreamBlocked``), frames
        are parsed from the buffer. DATA frame payloads are delivered
        incrementally as bytes arrive; other frames are only handled once
        fully buffered. Processed bytes are removed from the buffer, and the
        partial-frame state is kept in ``stream.frame_type`` /
        ``stream.frame_size``.

        :param stream: the stream the data arrived on; its buffer and frame
            state are updated in place.
        :param data: newly received bytes.
        :param stream_ended: whether the peer ended the stream with this data.
        :return: events produced by the frames handled.
        """
        http_events: List[H3Event] = []

        stream.buffer += data
        if stream_ended:
            stream.ended = True
        if stream.blocked:
            return http_events

        # shortcut for DATA frame fragments
        # (the buffer holds less than the rest of the current DATA frame, so
        # deliver it as-is without parsing).
        # NOTE(review): this path reports stream_ended=False even if the FIN
        # arrived with this data, so an end of stream mid-frame yields no end
        # event — confirm this is the intended handling of a truncated frame.
        if (stream.frame_type == FrameType.DATA
                and stream.frame_size is not None
                and len(stream.buffer) < stream.frame_size):
            http_events.append(
                DataReceived(
                    data=stream.buffer,
                    push_id=stream.push_id,
                    stream_id=stream.stream_id,
                    stream_ended=False,
                ))
            stream.frame_size -= len(stream.buffer)
            stream.buffer = b""
            return http_events

        # handle lone FIN
        if stream_ended and not stream.buffer:
            http_events.append(
                DataReceived(
                    data=b"",
                    push_id=stream.push_id,
                    stream_id=stream.stream_id,
                    stream_ended=True,
                ))
            return http_events

        buf = Buffer(data=stream.buffer)
        # offset into stream.buffer up to which bytes have been processed
        consumed = 0

        while not buf.eof():
            # fetch next frame header
            if stream.frame_size is None:
                try:
                    stream.frame_type = buf.pull_uint_var()
                    stream.frame_size = buf.pull_uint_var()
                except BufferReadError:
                    break
                consumed = buf.tell()

                # log frame
                # (only DATA frames are logged here)
                if (self._quic_logger is not None
                        and stream.frame_type == FrameType.DATA):
                    self._quic_logger.log_event(
                        category="http",
                        event="frame_parsed",
                        data=qlog_encode_data_frame(
                            byte_length=stream.frame_size,
                            stream_id=stream.stream_id),
                    )

            # check how much data is available
            # (consumed == buf.tell() at this point, so this is the unread
            # byte count, assuming buf.capacity == len(stream.buffer))
            chunk_size = min(stream.frame_size, buf.capacity - consumed)
            # non-DATA frames are only handled once fully buffered
            if stream.frame_type != FrameType.DATA and chunk_size < stream.frame_size:
                break

            # read available data
            frame_data = buf.pull_bytes(chunk_size)
            consumed = buf.tell()

            # detect end of frame
            stream.frame_size -= chunk_size
            if not stream.frame_size:
                stream.frame_size = None

            try:
                http_events.extend(
                    self._handle_request_or_push_frame(
                        frame_type=stream.frame_type,
                        frame_data=frame_data,
                        stream=stream,
                        stream_ended=stream.ended and buf.eof(),
                    ))
            except pylsqpack.StreamBlocked:
                # the frame's bytes are already counted in ``consumed`` and are
                # trimmed below; blocked_frame_size records their length
                stream.blocked = True
                stream.blocked_frame_size = len(frame_data)
                break

        # remove processed data from buffer
        stream.buffer = stream.buffer[consumed:]

        return http_events
예제 #23
0
    def _receive_stream_data_uni(self, stream: H3Stream, data: bytes,
                                 stream_ended: bool) -> List[H3Event]:
        """
        Handle data received on a unidirectional stream.

        ``data`` is appended to ``stream.buffer`` and parsed according to the
        stream type. If the type is not known yet, it is read from the start
        of the stream as a variable-length integer. At most one control, one
        QPACK decoder and one QPACK encoder stream is accepted from the peer.

        - CONTROL: each complete frame is passed to ``_handle_control_frame``.
        - PUSH: the push ID is read, then the stream is handed to
          ``_receive_request_or_push_data`` and its events are returned
          directly.
        - QPACK_DECODER: all buffered bytes are fed to ``self._encoder``.
        - QPACK_ENCODER: all buffered bytes are fed to ``self._decoder``.
          Every request/push stream this unblocks has its pending HEADERS
          frame resumed and its remaining buffered data processed.
        - Any other type: buffered bytes are discarded.

        Bytes that do not yet form a complete stream type, frame or push ID
        are left in ``stream.buffer`` for the next call.

        Returns the list of HTTP events produced.

        Raises StreamCreationError when the peer opens a second control /
        QPACK decoder / QPACK encoder stream. Raises QpackDecoderStreamError
        or QpackEncoderStreamError when pylsqpack rejects QPACK stream data.
        Anything raised by ``_handle_control_frame`` propagates.
        """
        http_events: List[H3Event] = []

        stream.buffer += data
        if stream_ended:
            stream.ended = True

        buf = Buffer(data=stream.buffer)
        # offset up to which stream.buffer has been fully processed; anything
        # after it is kept for the next call
        consumed = 0
        # ids of request/push streams whose header decoding the QPACK encoder
        # stream data has unblocked
        unblocked_streams: Set[int] = set()

        # A push stream enters the loop even with an empty buffer, so it always
        # reaches _receive_request_or_push_data (presumably so that a bare
        # end-of-stream is still forwarded). The PUSH branch either returns or
        # breaks, so this cannot spin.
        while stream.stream_type == StreamType.PUSH or not buf.eof():
            # fetch stream type for unidirectional streams
            if stream.stream_type is None:
                try:
                    stream.stream_type = buf.pull_uint_var()
                except BufferReadError:
                    # not enough bytes for the stream type yet
                    break
                consumed = buf.tell()

                # check unicity: the peer may open each of these stream types
                # only once
                if stream.stream_type == StreamType.CONTROL:
                    if self._peer_control_stream_id is not None:
                        raise StreamCreationError(
                            "Only one control stream is allowed")
                    self._peer_control_stream_id = stream.stream_id
                elif stream.stream_type == StreamType.QPACK_DECODER:
                    if self._peer_decoder_stream_id is not None:
                        raise StreamCreationError(
                            "Only one QPACK decoder stream is allowed")
                    self._peer_decoder_stream_id = stream.stream_id
                elif stream.stream_type == StreamType.QPACK_ENCODER:
                    if self._peer_encoder_stream_id is not None:
                        raise StreamCreationError(
                            "Only one QPACK encoder stream is allowed")
                    self._peer_encoder_stream_id = stream.stream_id

            if stream.stream_type == StreamType.CONTROL:
                # fetch next frame
                try:
                    frame_type = buf.pull_uint_var()
                    frame_length = buf.pull_uint_var()
                    frame_data = buf.pull_bytes(frame_length)
                except BufferReadError:
                    # incomplete frame: leave it buffered (consumed is not
                    # advanced) until more data arrives
                    break
                consumed = buf.tell()

                self._handle_control_frame(frame_type, frame_data)
            elif stream.stream_type == StreamType.PUSH:
                # fetch push id
                if stream.push_id is None:
                    try:
                        stream.push_id = buf.pull_uint_var()
                    except BufferReadError:
                        break
                    consumed = buf.tell()

                # remove processed data from buffer
                stream.buffer = stream.buffer[consumed:]

                # the rest of the push stream is parsed by the request/push
                # frame parser, which reads from stream.buffer itself
                return self._receive_request_or_push_data(
                    stream, b"", stream_ended)
            elif stream.stream_type == StreamType.QPACK_DECODER:
                # the peer's QPACK decoder stream carries instructions for our
                # encoder: feed all unframed bytes to self._encoder
                data = buf.pull_bytes(buf.capacity - buf.tell())
                consumed = buf.tell()
                try:
                    self._encoder.feed_decoder(data)
                except pylsqpack.DecoderStreamError as exc:
                    raise QpackDecoderStreamError() from exc
                self._decoder_bytes_received += len(data)
            elif stream.stream_type == StreamType.QPACK_ENCODER:
                # the peer's QPACK encoder stream carries instructions for our
                # decoder; feeding it yields the ids of streams it unblocks
                data = buf.pull_bytes(buf.capacity - buf.tell())
                consumed = buf.tell()
                try:
                    unblocked_streams.update(self._decoder.feed_encoder(data))
                except pylsqpack.EncoderStreamError as exc:
                    raise QpackEncoderStreamError() from exc
                self._encoder_bytes_received += len(data)
            else:
                # unknown stream type, discard data
                buf.seek(buf.capacity)
                consumed = buf.tell()

        # remove processed data from buffer
        stream.buffer = stream.buffer[consumed:]

        # process unblocked streams
        for stream_id in unblocked_streams:
            # NOTE: `stream` is rebound to the unblocked request/push stream
            # here; the unidirectional stream's buffer was already trimmed above
            stream = self._stream[stream_id]

            # resume headers; frame_data=None presumably asks
            # _handle_request_or_push_frame to resume the blocked header block
            # rather than decode new bytes -- TODO confirm.
            # The headers are the final event only if the stream has ended and
            # nothing else is buffered behind them.
            http_events.extend(
                self._handle_request_or_push_frame(
                    frame_type=FrameType.HEADERS,
                    frame_data=None,
                    stream=stream,
                    stream_ended=stream.ended and not stream.buffer,
                ))
            stream.blocked = False
            stream.blocked_frame_size = None

            # resume processing of whatever was buffered while blocked
            if stream.buffer:
                http_events.extend(
                    self._receive_request_or_push_data(stream, b"",
                                                       stream.ended))

        return http_events
예제 #24
0
    def _receive_stream_data_bidi(self, stream_id: int, data: bytes,
                                  stream_ended: bool) -> List[HttpEvent]:
        """
        Handle data received on a bidirectional stream.

        Client-initiated bidirectional streams carry requests and responses.
        ``data`` is appended to the stream's buffer, which is then parsed into
        HTTP/3 frames:

        - DATA frames produce ``DataReceived`` events. Partial frames are
          delivered incrementally as their bytes arrive.
        - HEADERS frames are decoded only once complete. They produce a
          ``ResponseReceived`` (client) or ``RequestReceived`` (server) event.
          If QPACK decoding is blocked, the stream is marked as blocked and
          later data is only buffered.
        - Other frame types are consumed and ignored.

        Bytes that do not yet form a complete frame header (or a complete
        HEADERS frame) stay in the stream's buffer for the next call.

        When the peer ends the stream and every buffered byte has been
        consumed, the end is reported on the last event produced. If no event
        would otherwise carry it (e.g. the stream ends with an ignored frame
        type), an empty ``DataReceived(stream_ended=True)`` is emitted so the
        end of stream is never lost.

        :param stream_id: the QUIC stream ID.
        :param data: newly received bytes.
        :param stream_ended: whether the peer has finished the stream.
        :return: the HTTP events produced by this call.
        """
        http_events: List[HttpEvent] = []

        stream = self._stream[stream_id]
        stream.buffer += data
        if stream_ended:
            stream.ended = True
        if stream.blocked:
            # header decoding is waiting for QPACK encoder stream data; keep
            # buffering until the stream is unblocked
            return http_events

        # shortcut: the buffer only holds more bytes of the current DATA frame,
        # so deliver them directly without running the frame parser
        if (stream.frame_size is not None
                and stream.frame_type == FrameType.DATA
                and len(stream.buffer) < stream.frame_size):
            # Like the main loop below: skip empty non-final chunks, and report
            # the end of stream even if the peer finished mid-frame (otherwise
            # the end would never be delivered).
            if stream.buffer or stream_ended:
                http_events.append(
                    DataReceived(data=stream.buffer,
                                 stream_id=stream_id,
                                 stream_ended=stream_ended))
            stream.frame_size -= len(stream.buffer)
            stream.buffer = b""
            return http_events

        # some peers (e.g. f5) end the stream with no data
        if stream_ended and not stream.buffer:
            http_events.append(
                DataReceived(data=b"", stream_id=stream_id, stream_ended=True))
            return http_events

        buf = Buffer(data=stream.buffer)
        # offset up to which stream.buffer has been fully processed
        consumed = 0
        # whether an emitted event already carries stream_ended=True
        end_reported = False

        while not buf.eof():
            # fetch next frame header
            if stream.frame_size is None:
                try:
                    stream.frame_type = buf.pull_uint_var()
                    stream.frame_size = buf.pull_uint_var()
                except BufferReadError:
                    # incomplete frame header: retry once more data arrives
                    break
                consumed = buf.tell()

            # check how much data is available; HEADERS frames must be decoded
            # as a whole, so wait for the rest of them
            chunk_size = min(stream.frame_size, buf.capacity - consumed)
            if (stream.frame_type == FrameType.HEADERS
                    and chunk_size < stream.frame_size):
                break

            # read available data
            frame_data = buf.pull_bytes(chunk_size)
            consumed = buf.tell()

            # detect end of frame
            stream.frame_size -= chunk_size
            if not stream.frame_size:
                stream.frame_size = None

            # this is the final event only if the peer ended the stream and no
            # buffered bytes remain after this frame (chunk)
            at_end = stream_ended and buf.eof()

            if stream.frame_type == FrameType.DATA and (stream_ended
                                                        or frame_data):
                http_events.append(
                    DataReceived(
                        data=frame_data,
                        stream_id=stream_id,
                        stream_ended=at_end,
                    ))
                end_reported = end_reported or at_end
            elif stream.frame_type == FrameType.HEADERS:
                try:
                    decoder, headers = self._decoder.feed_header(
                        stream_id, frame_data)
                except StreamBlocked:
                    stream.blocked = True
                    break
                # send the QPACK decoder stream bytes produced by decoding
                self._quic.send_stream_data(self._local_decoder_stream_id,
                                            decoder)
                cls = ResponseReceived if self._is_client else RequestReceived
                http_events.append(
                    cls(
                        headers=headers,
                        stream_id=stream_id,
                        stream_ended=at_end,
                    ))
                end_reported = end_reported or at_end
            # any other frame type is consumed and ignored

        # remove processed data from buffer
        stream.buffer = stream.buffer[consumed:]

        # The peer ended the stream and everything was consumed, but no event
        # carried the end (e.g. the final frame was of an ignored type).
        # Report it explicitly. Blocked streams are excluded: their end is
        # delivered when header decoding resumes.
        if (stream_ended and not end_reported and not stream.blocked
                and not stream.buffer):
            http_events.append(
                DataReceived(data=b"", stream_id=stream_id, stream_ended=True))

        return http_events