Ejemplo n.º 1
0
 def test_pull_long_header_scid_too_long(self):
     """Reject a long header that declares a 21-byte source CID."""
     # first byte 0xc2, version 0xff000016, DCID length 0x00, SCID length 0x15
     packet = binascii.unhexlify(
         "c2ff0000160015000000000000000000000000000000000000000000004"
         "01cfcee99ec4bbf1f7a30f9b0c9417b8c263cdd8cc972a4439d68a46320")
     reader = Buffer(data=packet)
     with self.assertRaises(ValueError) as ctx:
         pull_quic_header(reader, host_cid_length=8)
     message = str(ctx.exception)
     self.assertEqual(message, "Source CID is too long (21 bytes)")
Ejemplo n.º 2
0
 def test_pull_long_header_dcid_too_long(self):
     """Reject a long header that declares a 21-byte destination CID."""
     # first byte 0xc6, version 0xff000016, DCID length 0x15
     packet = binascii.unhexlify(
         "c6ff0000161500000000000000000000000000000000000000000000004"
         "01c514f99ec4bbf1f7a30f9b0c94fef717f1c1d07fec24c99a864da7ede")
     reader = Buffer(data=packet)
     with self.assertRaises(ValueError) as ctx:
         pull_quic_header(reader, host_cid_length=8)
     message = str(ctx.exception)
     self.assertEqual(message, "Destination CID is too long (21 bytes)")
Ejemplo n.º 3
0
    def test_pull_retry(self):
        """Parse the draft-25 Retry fixture and validate its integrity tag."""
        stream = Buffer(data=load("retry.bin"))
        retry = pull_quic_header(stream, host_cid_length=8)

        self.assertTrue(retry.is_long_header)
        expected = (
            ("version", QuicProtocolVersion.DRAFT_25),
            ("packet_type", PACKET_TYPE_RETRY),
            ("destination_cid", binascii.unhexlify("e9d146d8d14cb28e")),
            ("source_cid",
             binascii.unhexlify("0b0a205a648fcf82d85f128b67bbe08053e6")),
            ("token", binascii.unhexlify(
                "44397a35d698393c134b08a932737859f446d3aadd00ed81540c8d8de172"
                "906d3e7a111b503f9729b8928e7528f9a86a4581f9ebb4cb3b53c283661e"
                "8530741a99192ee56914c5626998ec0f")),
            ("integrity_tag",
             binascii.unhexlify("e1a3c80c797ea401c08fc9c342a2d90d")),
            ("rest_length", 0),
        )
        for attr, wanted in expected:
            self.assertEqual(getattr(retry, attr), wanted)
        self.assertEqual(stream.tell(), 125)

        # Recompute the tag over the first 109 bytes (the 125-byte packet
        # minus its trailing 16-byte tag), using original DCID fbbd219b7363b64b.
        odcid = binascii.unhexlify("fbbd219b7363b64b")
        computed_tag = get_retry_integrity_tag(stream.data_slice(0, 109), odcid)
        self.assertEqual(computed_tag, retry.integrity_tag)
Ejemplo n.º 4
0
    async def handle(self, event: Event) -> None:
        """Process one transport event.

        Only ``RawData`` events are acted on; anything else (including
        ``Closed``) is ignored. A datagram whose QUIC header fails to parse
        with ``ValueError`` is dropped. Then:

        * A header carrying a version outside
          ``self.quic_config.supported_versions`` is answered with a Version
          Negotiation packet.
        * Otherwise it is fed to the connection registered for its
          destination CID.
        * If no connection is registered, a new ``QuicConnection`` is created,
          but only for an Initial packet in a datagram of at least 1200 bytes.
        """
        if not isinstance(event, RawData):
            # Closed (and any other event type) needs no handling here.
            return

        try:
            header = pull_quic_header(Buffer(data=event.data), host_cid_length=8)
        except ValueError:
            return

        if (header.version is not None
                and header.version not in self.quic_config.supported_versions):
            reply = encode_quic_version_negotiation(
                source_cid=header.destination_cid,
                destination_cid=header.source_cid,
                supported_versions=self.quic_config.supported_versions,
            )
            await self.send(RawData(data=reply, address=event.address))
            return

        dcid = header.destination_cid
        connection = self.connections.get(dcid)
        if connection is None:
            # Unknown connection: only a full-size Initial may open a new one.
            if len(event.data) < 1200 or header.packet_type != PACKET_TYPE_INITIAL:
                return
            connection = QuicConnection(
                configuration=self.quic_config, original_connection_id=None
            )
            # Register under both the client-chosen DCID and our own host CID.
            self.connections[dcid] = connection
            self.connections[connection.host_cid] = connection

        connection.receive_datagram(event.data, event.address, now=self.now())
        await self._handle_events(connection, event.address)
Ejemplo n.º 5
0
    def _make_connection(self, channel, data):
        """Attach ``channel`` to a QUIC / HTTP-3 connection for datagram ``data``.

        The QUIC header of ``data`` is parsed with the context's connection
        ID length. Then, in order:

        * A header carrying a version not in ``ctx.supported_versions`` is
          answered on ``channel`` with a Version Negotiation packet.
        * A known destination CID re-links the existing connection to
          ``channel`` (closing the previously linked channel) and stores it
          on ``self.quic`` / ``self.conn``.
        * Anything other than an Initial packet in a datagram of at least
          1200 bytes is ignored.
        * When retry is enabled (``self._retry``), an Initial without a token
          is answered with a Retry packet. An Initial with an invalid token
          is dropped.
        * Otherwise a new ``QuicConnection`` wrapped in an ``H3Connection``
          is created and linked to ``channel``.

        Errors raised by ``pull_quic_header`` are not caught here.
        """
        ctx = channel.server.ctx

        buf = Buffer(data=data)
        header = pull_quic_header(buf,
                                  host_cid_length=ctx.connection_id_length)
        # version negotiation
        if header.version is not None and header.version not in ctx.supported_versions:
            # Reply on the channel the datagram arrived on (the ``channel``
            # argument), exactly like the Retry path below does.
            channel.push(
                encode_quic_version_negotiation(
                    source_cid=header.destination_cid,
                    destination_cid=header.source_cid,
                    supported_versions=ctx.supported_versions,
                ))
            return

        # existing connection: move it over to the new channel
        conn = self.conns.get(header.destination_cid)
        if conn:
            conn._linked_channel.close()
            conn._linked_channel = channel
            self.quic = conn._quic
            self.conn = conn
            return

        # only a full-size Initial packet may open a new connection
        if header.packet_type != PACKET_TYPE_INITIAL or len(data) < 1200:
            return

        original_connection_id = None
        if self._retry is not None:
            if not header.token:
                # create a retry token
                channel.push(
                    encode_quic_retry(
                        version=header.version,
                        source_cid=os.urandom(8),
                        destination_cid=header.source_cid,
                        original_destination_cid=header.destination_cid,
                        retry_token=self._retry.create_token(
                            channel.addr, header.destination_cid)))
                return
            else:
                try:
                    original_connection_id = self._retry.validate_token(
                        channel.addr, header.token)
                except ValueError:
                    # invalid token: drop the packet silently
                    return

        self.quic = QuicConnection(
            configuration=ctx,
            logger_connection_id=original_connection_id
            or header.destination_cid,
            original_connection_id=original_connection_id,
            session_ticket_fetcher=channel.server.ticket_store.pop,
            session_ticket_handler=channel.server.ticket_store.add)
        self.conn = H3Connection(self.quic)
        self.conn._linked_channel = channel
Ejemplo n.º 6
0
 def test_pull_initial_server(self):
     """Parse the draft-22 Initial packet fixture sent by a server."""
     packet = Buffer(data=load("initial_server.bin"))
     parsed = pull_quic_header(packet, host_cid_length=8)

     self.assertTrue(parsed.is_long_header)
     for attr, wanted in (
         ("version", QuicProtocolVersion.DRAFT_22),
         ("packet_type", PACKET_TYPE_INITIAL),
         ("destination_cid", b""),
         ("source_cid", binascii.unhexlify("195c68344e28d479")),
         ("original_destination_cid", b""),
         ("token", b""),
         ("rest_length", 184),
     ):
         self.assertEqual(getattr(parsed, attr), wanted)
     # header parsing stops after the first 18 bytes
     self.assertEqual(packet.tell(), 18)
Ejemplo n.º 7
0
 def test_pull_version_negotiation(self):
     """Parse the Version Negotiation fixture (8 bytes remain after the header)."""
     packet = Buffer(data=load("version_negotiation.bin"))
     parsed = pull_quic_header(packet, host_cid_length=8)

     self.assertTrue(parsed.is_long_header)
     for attr, wanted in (
         ("version", QuicProtocolVersion.NEGOTIATION),
         ("packet_type", None),
         ("destination_cid", binascii.unhexlify("9aac5a49ba87a849")),
         ("source_cid", binascii.unhexlify("f92f4336fa951ba1")),
         ("original_destination_cid", b""),
         ("token", b""),
         ("rest_length", 8),
     ):
         self.assertEqual(getattr(parsed, attr), wanted)
     self.assertEqual(packet.tell(), 23)
Ejemplo n.º 8
0
 def test_pull_short_header(self):
     """Parse the short-header fixture: no version, 8-byte host CID."""
     packet = Buffer(data=load("short_header.bin"))
     parsed = pull_quic_header(packet, host_cid_length=8)

     self.assertFalse(parsed.is_long_header)
     for attr, wanted in (
         ("version", None),
         ("packet_type", 0x50),
         ("destination_cid", binascii.unhexlify("f45aa7b59c0e1ad6")),
         ("source_cid", b""),
         ("original_destination_cid", b""),
         ("token", b""),
         ("rest_length", 12),
     ):
         self.assertEqual(getattr(parsed, attr), wanted)
     # 1 first byte + 8-byte destination CID
     self.assertEqual(packet.tell(), 9)
Ejemplo n.º 9
0
 def test_pull_initial_client(self):
     """Parse the draft-25 Initial packet fixture sent by a client."""
     packet = Buffer(data=load("initial_client.bin"))
     parsed = pull_quic_header(packet, host_cid_length=8)

     self.assertTrue(parsed.is_long_header)
     for attr, wanted in (
         ("version", QuicProtocolVersion.DRAFT_25),
         ("packet_type", PACKET_TYPE_INITIAL),
         ("destination_cid", binascii.unhexlify("858b39368b8e3c6e")),
         ("source_cid", b""),
         ("token", b""),
         ("integrity_tag", b""),
         ("rest_length", 1262),
     ):
         self.assertEqual(getattr(parsed, attr), wanted)
     self.assertEqual(packet.tell(), 18)
Ejemplo n.º 10
0
    def test_pull_retry(self):
        """Parse the version 1 Retry fixture and check it end to end.

        The test checks three things:

        * the parsed header fields;
        * that the integrity tag recomputed over the packet prefix matches
          the one carried in the packet;
        * that re-encoding the parsed fields reproduces the fixture bytes.

        Nothing is written to disk.
        """
        original_destination_cid = binascii.unhexlify("fbbd219b7363b64b")

        data = load("retry.bin")
        buf = Buffer(data=data)
        header = pull_quic_header(buf, host_cid_length=8)
        self.assertTrue(header.is_long_header)
        self.assertEqual(header.version, QuicProtocolVersion.VERSION_1)
        self.assertEqual(header.packet_type, PACKET_TYPE_RETRY)
        self.assertEqual(header.destination_cid,
                         binascii.unhexlify("e9d146d8d14cb28e"))
        self.assertEqual(
            header.source_cid,
            binascii.unhexlify("0b0a205a648fcf82d85f128b67bbe08053e6"),
        )
        self.assertEqual(
            header.token,
            binascii.unhexlify(
                "44397a35d698393c134b08a932737859f446d3aadd00ed81540c8d8de172"
                "906d3e7a111b503f9729b8928e7528f9a86a4581f9ebb4cb3b53c283661e"
                "8530741a99192ee56914c5626998ec0f"),
        )
        self.assertEqual(
            header.integrity_tag,
            binascii.unhexlify("4620aafd42f1d630588b27575a12da5c"))
        self.assertEqual(header.rest_length, 0)
        self.assertEqual(buf.tell(), 125)

        # check integrity: the tag covers the first 109 bytes (the 125-byte
        # packet minus its trailing 16-byte tag)
        self.assertEqual(
            get_retry_integrity_tag(
                buf.data_slice(0, 109),
                original_destination_cid,
                version=header.version,
            ),
            header.integrity_tag,
        )

        # serialize and compare with the fixture
        encoded = encode_quic_retry(
            version=header.version,
            source_cid=header.source_cid,
            destination_cid=header.destination_cid,
            original_destination_cid=original_destination_cid,
            retry_token=header.token,
        )
        self.assertEqual(encoded, data)
Ejemplo n.º 11
0
    def test_pull_retry_draft_28(self):
        """Parse, integrity-check and re-encode the draft-28 Retry fixture."""
        odcid = binascii.unhexlify("fbbd219b7363b64b")

        raw = load("retry_draft_28.bin")
        stream = Buffer(data=raw)
        retry = pull_quic_header(stream, host_cid_length=8)

        self.assertTrue(retry.is_long_header)
        expected = (
            ("version", QuicProtocolVersion.DRAFT_28),
            ("packet_type", PACKET_TYPE_RETRY),
            ("destination_cid", binascii.unhexlify("e9d146d8d14cb28e")),
            ("source_cid",
             binascii.unhexlify("0b0a205a648fcf82d85f128b67bbe08053e6")),
            ("token", binascii.unhexlify(
                "44397a35d698393c134b08a932737859f446d3aadd00ed81540c8d8de172"
                "906d3e7a111b503f9729b8928e7528f9a86a4581f9ebb4cb3b53c283661e"
                "8530741a99192ee56914c5626998ec0f")),
            ("integrity_tag",
             binascii.unhexlify("f15154a271f10139ef6b129033ac38ae")),
            ("rest_length", 0),
        )
        for attr, wanted in expected:
            self.assertEqual(getattr(retry, attr), wanted)
        self.assertEqual(stream.tell(), 125)

        # the tag is recomputed over everything except its own 16 bytes
        computed_tag = get_retry_integrity_tag(
            stream.data_slice(0, 109), odcid, version=retry.version)
        self.assertEqual(computed_tag, retry.integrity_tag)

        # re-encoding the parsed fields must reproduce the fixture exactly
        rebuilt = encode_quic_retry(
            version=retry.version,
            source_cid=retry.source_cid,
            destination_cid=retry.destination_cid,
            original_destination_cid=odcid,
            retry_token=retry.token,
        )
        self.assertEqual(rebuilt, raw)
Ejemplo n.º 12
0
 def test_pull_retry(self):
     """Parse the draft-22 Retry fixture, which carries an original DCID field."""
     packet = Buffer(data=load("retry.bin"))
     parsed = pull_quic_header(packet, host_cid_length=8)

     self.assertTrue(parsed.is_long_header)
     retry_token = binascii.unhexlify(
         "5282f57f85a1a5c50de5aac2ff7dba43ff34524737099ec41c4b8e8c76734f935e8efd51177dbbe764"
     )
     for attr, wanted in (
         ("version", QuicProtocolVersion.DRAFT_22),
         ("packet_type", PACKET_TYPE_RETRY),
         ("destination_cid", binascii.unhexlify("fee746dfde699d61")),
         ("source_cid", binascii.unhexlify("59aa0942fd2f11e9")),
         ("original_destination_cid", binascii.unhexlify("d61e7448e0d63dff")),
         ("token", retry_token),
         ("rest_length", 0),
     ):
         self.assertEqual(getattr(parsed, attr), wanted)
     self.assertEqual(packet.tell(), 73)
Ejemplo n.º 13
0
    def test_pull_version_negotiation(self):
        """Parse the Version Negotiation fixture and decode its version list.

        The 23-byte header is followed by 8 bytes, read back as two 32-bit
        version numbers.
        """
        buf = Buffer(data=load("version_negotiation.bin"))
        header = pull_quic_header(buf, host_cid_length=8)
        self.assertTrue(header.is_long_header)
        self.assertEqual(header.version, QuicProtocolVersion.NEGOTIATION)
        self.assertEqual(header.packet_type, None)
        self.assertEqual(header.destination_cid,
                         binascii.unhexlify("9aac5a49ba87a849"))
        self.assertEqual(header.source_cid,
                         binascii.unhexlify("f92f4336fa951ba1"))
        self.assertEqual(header.token, b"")
        self.assertEqual(header.integrity_tag, b"")
        self.assertEqual(header.rest_length, 8)
        self.assertEqual(buf.tell(), 23)

        # the remainder of the packet is a list of 32-bit versions
        versions = []
        while not buf.eof():
            versions.append(buf.pull_uint32())
        self.assertEqual(versions,
                         [0x45474716, QuicProtocolVersion.VERSION_1])
Ejemplo n.º 14
0
 def test_pull_empty(self):
     """Parsing an empty buffer raises BufferReadError."""
     empty = Buffer(data=b"")
     with self.assertRaises(BufferReadError):
         pull_quic_header(empty, host_cid_length=8)
Ejemplo n.º 15
0
 def test_pull_short_header_no_fixed_bit(self):
     """A single 0x00 byte has the fixed bit cleared and is rejected."""
     zero_byte = Buffer(data=b"\x00")
     with self.assertRaises(ValueError) as ctx:
         pull_quic_header(zero_byte, host_cid_length=8)
     message = str(ctx.exception)
     self.assertEqual(message, "Packet fixed bit is zero")
Ejemplo n.º 16
0
 def test_pull_long_header_too_short(self):
     """A two-byte buffer starting with first byte 0xc0 runs out of data."""
     truncated = Buffer(data=b"\xc0\x00")
     with self.assertRaises(BufferReadError):
         pull_quic_header(truncated, host_cid_length=8)
Ejemplo n.º 17
0
 def test_pull_initial_client_truncated(self):
     """Only the first 100 bytes of the client Initial fixture are supplied."""
     prefix = load("initial_client.bin")[0:100]
     truncated = Buffer(data=prefix)
     with self.assertRaises(ValueError) as ctx:
         pull_quic_header(truncated, host_cid_length=8)
     message = str(ctx.exception)
     self.assertEqual(message, "Packet payload is truncated")
Ejemplo n.º 18
0
    def datagram_received(self, data: Union[bytes, Text],
                          addr: NetworkAddress) -> None:
        """Dispatch one incoming UDP datagram to its QUIC protocol instance.

        Steps:

        1. Parse the QUIC header. Datagrams whose header raises ``ValueError``
           are dropped.
        2. A header whose version is set but not in
           ``self._configuration.supported_versions`` is answered with a
           Version Negotiation packet.
        3. An Initial packet that meets both conditions below opens a new
           connection:

           * it arrives in a datagram of at least 1200 bytes;
           * its destination CID is not registered.

           Opening it involves these steps:

           * When retry is enabled (``self._retry``), address validation
             runs first. A token-less Initial gets a Retry packet. A token
             is validated to recover the original destination CID and the
             Retry source CID. Invalid tokens are dropped.
           * A ``QuicConnection`` and a protocol are then created.
           * The protocol is registered under both the client-chosen
             destination CID and the connection's own host CID.
        4. The datagram is forwarded to the matching protocol, if any.
           Datagrams for unknown connections that do not open a new one are
           dropped.

        :param data: raw datagram payload (cast to ``bytes``).
        :param addr: address of the sending peer.
        """
        global totalDatagrams

        data = cast(bytes, data)
        buf = Buffer(data=data)
        # module-level count of every datagram received (presumably for
        # diagnostics)
        totalDatagrams += 1

        try:
            header = pull_quic_header(
                buf, host_cid_length=self._configuration.connection_id_length)
        except ValueError:
            return

        # version negotiation
        if (header.version is not None and header.version
                not in self._configuration.supported_versions):
            self._transport.sendto(
                encode_quic_version_negotiation(
                    source_cid=header.destination_cid,
                    destination_cid=header.source_cid,
                    supported_versions=self._configuration.supported_versions,
                ),
                addr,
            )
            return

        protocol = self._protocols.get(header.destination_cid, None)
        original_destination_connection_id: Optional[bytes] = None
        retry_source_connection_id: Optional[bytes] = None
        if (protocol is None and len(data) >= 1200
                and header.packet_type == PACKET_TYPE_INITIAL):
            if self._retry is not None:
                if not header.token:
                    # no token yet: answer with a Retry and wait for the
                    # client to come back with the token
                    source_cid = os.urandom(8)
                    self._transport.sendto(
                        encode_quic_retry(
                            version=header.version,
                            source_cid=source_cid,
                            destination_cid=header.source_cid,
                            original_destination_cid=header.destination_cid,
                            retry_token=self._retry.create_token(
                                addr, header.destination_cid, source_cid),
                        ),
                        addr,
                    )
                    return
                # Validate the token and recover the original destination CID
                # and the Retry source CID that were bound into it by
                # create_token(). These must reach the new connection.
                try:
                    (original_destination_connection_id,
                     retry_source_connection_id) = self._retry.validate_token(
                         addr, header.token)
                except ValueError:
                    return
            else:
                original_destination_connection_id = header.destination_cid

            # create new connection
            connection = QuicConnection(
                configuration=self._configuration,
                original_destination_connection_id=
                original_destination_connection_id,
                retry_source_connection_id=retry_source_connection_id,
                session_ticket_handler=self._session_ticket_handler,
                session_ticket_fetcher=self._session_ticket_fetcher,
            )

            # wrap the connection in a protocol object via the configured
            # factory and attach it to the shared transport
            protocol = self._create_protocol(
                connection, stream_handler=self._stream_handler)
            protocol.connection_made(self._transport)

            # register callbacks
            protocol._connection_id_issued_handler = partial(
                self._connection_id_issued, protocol=protocol)
            protocol._connection_id_retired_handler = partial(
                self._connection_id_retired, protocol=protocol)
            protocol._connection_terminated_handler = partial(
                self._connection_terminated, protocol=protocol)

            # route datagrams addressed to either CID to this protocol
            self._protocols[header.destination_cid] = protocol
            self._protocols[connection.host_cid] = protocol

        if protocol is not None:
            protocol.datagram_received(data, addr)