Exemplo n.º 1
0
    async def _handle_events(
        self, connection: QuicConnection, client: Optional[Tuple[str, int]] = None
    ) -> None:
        """Process every queued event of ``connection``, then flush and re-arm it.

        Connection-ID events keep ``self.connections`` (connection ID ->
        connection) up to date, and ``ProtocolNegotiated`` creates the
        ``H3Protocol`` for this connection. Once that protocol exists, every
        event (including the one that created it) is forwarded to it. When the
        queue is empty, pending datagrams are sent and the connection's next
        timer, if any, is scheduled through ``self.call_at``.
        """
        while True:
            quic_event = connection.next_event()
            if quic_event is None:
                break

            if isinstance(quic_event, ConnectionTerminated):
                # No bookkeeping here; the event is still forwarded below.
                pass
            elif isinstance(quic_event, ProtocolNegotiated):
                send_callback = partial(self.send_all, connection)
                self.http_connections[connection] = H3Protocol(
                    self.config,
                    client,
                    self.server,
                    self.spawn_app,
                    connection,
                    send_callback,
                )
            elif isinstance(quic_event, ConnectionIdIssued):
                self.connections[quic_event.connection_id] = connection
            elif isinstance(quic_event, ConnectionIdRetired):
                del self.connections[quic_event.connection_id]

            h3_protocol = self.http_connections.get(connection)
            if h3_protocol is not None:
                await h3_protocol.handle(quic_event)

        await self.send_all(connection)

        deadline = connection.get_timer()
        if deadline is None:
            return
        self.call_at(deadline, partial(self._handle_timer, connection))
Exemplo n.º 2
0
    async def handle(self, event: Event) -> None:
        """Feed one transport event into the QUIC layer.

        Only ``RawData`` is acted upon; every other event (e.g. ``Closed``) is
        ignored. A datagram whose header cannot be parsed is dropped. An
        unsupported version is answered with a Version Negotiation packet. A
        datagram of at least 1200 bytes carrying an Initial packet for an
        unknown destination connection ID creates a new ``QuicConnection``,
        registered under both that ID and the connection's own host CID.
        Otherwise datagrams for unknown IDs are dropped.
        """
        if not isinstance(event, RawData):
            return

        try:
            header = pull_quic_header(Buffer(data=event.data), host_cid_length=8)
        except ValueError:
            return

        versions = self.quic_config.supported_versions
        if header.version is not None and header.version not in versions:
            reply = encode_quic_version_negotiation(
                source_cid=header.destination_cid,
                destination_cid=header.source_cid,
                supported_versions=versions,
            )
            await self.send(RawData(data=reply, address=event.address))
            return

        connection = self.connections.get(header.destination_cid)
        if connection is None:
            opens_connection = (
                len(event.data) >= 1200
                and header.packet_type == PACKET_TYPE_INITIAL
            )
            if not opens_connection:
                return
            connection = QuicConnection(
                configuration=self.quic_config, original_connection_id=None
            )
            self.connections[header.destination_cid] = connection
            self.connections[connection.host_cid] = connection

        connection.receive_datagram(event.data, event.address, now=self.now())
        await self._handle_events(connection, event.address)
Exemplo n.º 3
0
 async def _handle_timer(self, timer: float,
                         connection: QuicConnection) -> None:
     """Sleep until ``timer`` on the context clock, then fire the QUIC timer.

     The timer is only delivered while ``connection._close_at`` is set; the
     events it produces are then processed without a client address.
     """
     remaining = timer - self.context.time()
     await self.context.sleep(max(0, remaining))
     if connection._close_at is None:
         return
     connection.handle_timer(now=self.context.time())
     await self._handle_events(connection, None)
Exemplo n.º 4
0
    def _make_connection(self, channel, data):
        """Attach this handler to a QUIC connection, based on the first datagram.

        On return ``self.quic``/``self.conn`` are set when a connection was
        found or created. They stay None when the datagram was dropped or
        answered statelessly (Version Negotiation or Retry). Those replies
        are pushed to ``channel``.

        Args:
            channel: channel the datagram arrived on. Its ``server.ctx`` is the
                QUIC configuration and ``addr`` the peer address.
            data: raw datagram bytes.
        """
        ctx = channel.server.ctx

        try:
            header = pull_quic_header(Buffer(data=data),
                                      host_cid_length=ctx.connection_id_length)
        except ValueError:
            # Not a parseable QUIC packet: drop it instead of raising.
            return

        # version negotiation
        if header.version is not None and header.version not in ctx.supported_versions:
            channel.push(
                encode_quic_version_negotiation(
                    source_cid=header.destination_cid,
                    destination_cid=header.source_cid,
                    supported_versions=ctx.supported_versions,
                ))
            return

        conn = self.conns.get(header.destination_cid)
        if conn is not None:
            # Known connection showing up on another channel: move it here.
            # _linked_channel may already be None (it is cleared when one of
            # the connection's IDs is retired while others remain registered).
            if conn._linked_channel is not None:
                conn._linked_channel.close()
            conn._linked_channel = channel
            self.quic = conn._quic
            self.conn = conn
            return

        # Only a full-size (>= 1200 bytes) Initial packet may open a connection.
        if header.packet_type != PACKET_TYPE_INITIAL or len(data) < 1200:
            return

        original_connection_id = None
        if self._retry is not None:
            if not header.token:
                # create a retry token
                channel.push(
                    encode_quic_retry(
                        version=header.version,
                        source_cid=os.urandom(8),
                        destination_cid=header.source_cid,
                        original_destination_cid=header.destination_cid,
                        retry_token=self._retry.create_token(
                            channel.addr, header.destination_cid)))
                return
            try:
                original_connection_id = self._retry.validate_token(
                    channel.addr, header.token)
            except ValueError:
                # Invalid or foreign token: drop the packet.
                return

        self.quic = QuicConnection(
            configuration=ctx,
            logger_connection_id=original_connection_id
            or header.destination_cid,
            original_connection_id=original_connection_id,
            session_ticket_fetcher=channel.server.ticket_store.pop,
            session_ticket_handler=channel.server.ticket_store.add)
        self.conn = H3Connection(self.quic)
        self.conn._linked_channel = channel
Exemplo n.º 5
0
 def connect(self):
     """Start a fresh QUIC + HTTP/3 connection to ``self.addr`` and send the first flight.

     A host that is not an IP literal becomes the TLS ``server_name``. For
     IP literals the configured ``server_name`` is left untouched.
     """
     host, _port = self.addr
     try:
         ipaddress.ip_address(host)
     except ValueError:
         self.configuration.server_name = host
     self._quic = QuicConnection(
         configuration=self.configuration,
         session_ticket_handler=self.save_session_ticket,
     )
     self._http = H3Connection(self._quic)
     self._quic.connect(self.addr, now=time.monotonic())
     self.transmit()
Exemplo n.º 6
0
async def connect(
    host: str,
    port: int,
    *,
    configuration: Optional[QuicConfiguration] = None,
    create_protocol: Optional[Callable] = QuicFactorySocket,
    session_ticket_handler: Optional[SessionTicketHandler] = None,
    stream_handler: Optional[QuicStreamHandler] = None,
    wait_connected: bool = True,
    local_port: int = 0,
) -> AsyncGenerator[QuicFactorySocket, None]:
    """Open a QUIC client connection to ``host``:``port`` and yield its protocol.

    Args:
        host: hostname or IP literal. Hostnames are also used as the TLS
            ``server_name``.
        port: destination UDP port.
        configuration: client configuration. A default
            ``QuicConfiguration(is_client=True)`` is created when omitted.
        create_protocol: factory called as
            ``create_protocol(connection, stream_handler=...)``.
        session_ticket_handler: passed to ``QuicConnection``.
        stream_handler: passed to the protocol factory.
        wait_connected: when true, wait for the handshake before yielding.
        local_port: local UDP port to bind (0 = any).

    On exit, including on errors while connecting or inside the caller's
    block, the protocol is closed, its closure awaited, and the UDP
    transport closed.
    """
    loop = asyncio.get_running_loop()
    # Bind the IPv6 wildcard; IPv4 peers are reached via IPv4-mapped addresses.
    local_host = "::"

    try:
        ipaddress.ip_address(host)
        server_name = None
    except ValueError:
        server_name = host

    infos = await loop.getaddrinfo(host, port, type=socket.SOCK_DGRAM)
    addr = infos[0][4]
    if len(addr) == 2:
        # IPv4 result: convert to an IPv4-mapped IPv6 address for the "::" socket.
        addr = ("::ffff:" + addr[0], addr[1], 0, 0)

    # prepare QUIC connection
    if configuration is None:
        configuration = QuicConfiguration(is_client=True)
    if server_name is not None:
        configuration.server_name = server_name
    connection = QuicConnection(configuration=configuration,
                                session_ticket_handler=session_ticket_handler)

    transport, protocol = await loop.create_datagram_endpoint(
        lambda: create_protocol(connection, stream_handler=stream_handler),
        local_addr=(local_host, local_port))
    protocol = cast(QuicFactorySocket, protocol)
    try:
        protocol.connect(addr)
        if wait_connected:
            await protocol.wait_connected()
        yield protocol
    finally:
        # Also runs on handshake failure or cancellation, so the socket
        # created by create_datagram_endpoint is never leaked.
        protocol.close()
        await protocol.wait_closed()
        transport.close()
Exemplo n.º 7
0
    async def connection(self, message):
        """Open a QUIC connection to ``message``'s destination and return its protocol.

        The destination comes from ``message.unresolved_remote`` when set.
        Otherwise it comes from ``message.opt.uri_host``/``uri_port``. A
        missing port falls back to ``self.default_port``. Sets ``self.addr``,
        ``self.quic`` and ``self.con``.

        Raises:
            ValueError: no destination host can be determined.
            error.NetworkError: creating the endpoint or the handshake failed
                with an ``OSError`` (chained as ``__cause__``).
        """
        if message.unresolved_remote is None:
            host = message.opt.uri_host
            port = message.opt.uri_port or self.default_port
            if host is None:
                raise ValueError(
                    "No location found to send message to (neither in .opt.uri_host nor in .remote)"
                )
        else:
            host, port = util.hostportsplit(message.unresolved_remote)
            port = port or self.default_port

        # Only hostnames (not IP literals) are used as TLS server_name.
        try:
            ipaddress.ip_address(host)
            server_name = None
        except ValueError:
            server_name = host

        infos = await self.loop.getaddrinfo(host, port, type=socket.SOCK_DGRAM)
        self.addr = infos[0][4]

        config = QuicConfiguration(is_client=True,
                                   # ALPN expects a *list* of protocol names; a
                                   # bare string would be offered as "c","o","a","p".
                                   alpn_protocols=['coap'],
                                   # aioquic takes seconds: 864000 s = 10 days.
                                   idle_timeout=864000,
                                   server_name=server_name)
        # NOTE(security): certificate verification is disabled; any server
        # certificate is accepted.
        config.verify_mode = ssl.CERT_NONE

        connection = QuicConnection(configuration=config)

        self.quic = Quic(connection)
        self.quic.ctx = self

        try:
            transport, protocol = await self.loop.create_datagram_endpoint(
                lambda: self.quic, remote_addr=(host, port))
            protocol.connect(self.addr)
            await protocol.wait_connected()
            self.con = True
        except OSError as exc:
            raise error.NetworkError("Connection failed to %r" % host) from exc

        return protocol
Exemplo n.º 8
0
 async def send_all(self, connection: QuicConnection) -> None:
     """Send every datagram ``connection`` has queued, each wrapped in ``RawData``."""
     pending = connection.datagrams_to_send(now=self.now())
     for payload, destination in pending:
         await self.send(RawData(data=payload, address=destination))
Exemplo n.º 9
0
 async def _handle_timer(self, connection: QuicConnection) -> None:
     """Fire the connection's QUIC timer, if a ``_close_at`` deadline is pending.

     Any events the timer produces are processed with no client address.
     """
     if connection._close_at is None:
         return
     connection.handle_timer(now=self.now())
     await self._handle_events(connection, None)
Exemplo n.º 10
0
def make_client():
    """Build a client ``QuicConnection`` that trusts ``SERVER_CACERTFILE``.

    ``_ack_delay`` is forced to 0 so acknowledgements are not delayed.
    """
    config = QuicConfiguration(is_client=True)
    config.load_verify_locations(cafile=SERVER_CACERTFILE)
    quic_client = QuicConnection(configuration=config)
    quic_client._ack_delay = 0
    return quic_client
Exemplo n.º 11
0
def make_connection():
    """Build a server-side ``H3Connection`` using the test certificate and key.

    The underlying ``QuicConnection`` has ``_ack_delay`` forced to 0.
    """
    config = QuicConfiguration(is_client=False)
    config.load_cert_chain(SERVER_CERTFILE, SERVER_KEYFILE)
    quic_server = QuicConnection(configuration=config)
    quic_server._ack_delay = 0
    return h3.H3Connection(quic_server)
Exemplo n.º 12
0
    def datagram_received(self, data: Union[bytes, Text],
                          addr: NetworkAddress) -> None:
        """Route one UDP datagram to the protocol that owns its connection.

        1. Parse the QUIC header. Unparseable datagrams are dropped.
        2. Answer unsupported versions with a Version Negotiation packet.
        3. For an unknown destination connection ID carrying a full-size
           (>= 1200 bytes) Initial packet:
           - optionally run stateless Retry (send a token, then validate it);
           - create a ``QuicConnection`` and its protocol;
           - register the protocol under both connection IDs.
        4. Hand the datagram to the matching protocol, if there is one.
        """
        data = cast(bytes, data)
        buf = Buffer(data=data)
        # Module-level statistic: counts every datagram, including dropped ones.
        global totalDatagrams
        totalDatagrams += 1
        try:
            header = pull_quic_header(
                buf, host_cid_length=self._configuration.connection_id_length)
        except ValueError:
            return

        # version negotiation
        if (header.version is not None and header.version
                not in self._configuration.supported_versions):
            self._transport.sendto(
                encode_quic_version_negotiation(
                    source_cid=header.destination_cid,
                    destination_cid=header.source_cid,
                    supported_versions=self._configuration.supported_versions,
                ),
                addr,
            )
            return

        protocol = self._protocols.get(header.destination_cid, None)
        original_destination_connection_id: Optional[bytes] = None
        retry_source_connection_id: Optional[bytes] = None
        if (protocol is None and len(data) >= 1200
                and header.packet_type == PACKET_TYPE_INITIAL):
            # retry
            if self._retry is not None:
                if not header.token:
                    # create a retry token
                    source_cid = os.urandom(8)
                    self._transport.sendto(
                        encode_quic_retry(
                            version=header.version,
                            source_cid=source_cid,
                            destination_cid=header.source_cid,
                            original_destination_cid=header.destination_cid,
                            retry_token=self._retry.create_token(
                                addr, header.destination_cid, source_cid),
                        ),
                        addr,
                    )
                    return
                # Validate the token. It returns the original destination CID
                # and the Retry source CID it was created with; both must reach
                # the new QuicConnection below.
                try:
                    (original_destination_connection_id,
                     retry_source_connection_id) = self._retry.validate_token(
                         addr, header.token)
                except ValueError:
                    return
            else:
                original_destination_connection_id = header.destination_cid

            # create new connection
            connection = QuicConnection(
                configuration=self._configuration,
                original_destination_connection_id=
                original_destination_connection_id,
                retry_source_connection_id=retry_source_connection_id,
                session_ticket_handler=self._session_ticket_handler,
                session_ticket_fetcher=self._session_ticket_fetcher,
            )

            # initiate the QuicSocketFactory class with the below call.
            protocol = self._create_protocol(
                connection, stream_handler=self._stream_handler)
            protocol.connection_made(self._transport)

            # register callbacks
            protocol._connection_id_issued_handler = partial(
                self._connection_id_issued, protocol=protocol)
            protocol._connection_id_retired_handler = partial(
                self._connection_id_retired, protocol=protocol)
            protocol._connection_terminated_handler = partial(
                self._connection_terminated, protocol=protocol)

            self._protocols[header.destination_cid] = protocol
            self._protocols[connection.host_cid] = protocol

        if protocol is not None:
            protocol.datagram_received(data, addr)
Exemplo n.º 13
0
class http3_request_handler(http2_handler.http2_request_handler):
    """HTTP/3 request handler built on an aioquic ``QuicConnection``.

    Reuses the HTTP/2 handler's request bookkeeping. On top of it, it adds:
    - QUIC connection setup (version negotiation, optional stateless retry);
    - connection-ID routing;
    - QUIC -> HTTP/3 event dispatch;
    - server-push bookkeeping (``push_id`` -> promised stream id).
    """
    producer_class = http3_producer
    # When true, new connections must complete a stateless Retry first.
    stateless_retry = True
    # Class attribute, hence shared by all handler instances: connection ID ->
    # H3Connection. _make_connection uses it to re-attach a known connection
    # that arrives on another channel.
    conns = {}
    errno = ErrorCode

    def __init__(self, handler, request):
        self._default_varialbes(handler, request)
        self.quic = None  # QUIC protocol
        self.conn = None  # HTTP3 Protocol

        self._push_map = {}  # push_id -> promised stream id
        self._retry = QuicRetryTokenHandler() if self.stateless_retry else None

    def _initiate_shutdown(self):
        # phase I: send goaway
        with self._plock:
            self.conn.shutdown(self._shutdown_reason[-1])

    def _proceed_shutdown(self):
        # phase II: close quic
        if self.quic:
            errcode, msg, _ = self._shutdown_reason
            with self._plock:
                self.quic.close(errcode, reason_phrase=msg or '')
            self.send_data()

    def _make_connection(self, channel, data):
        """Attach this handler to a QUIC connection, based on the first datagram.

        On return ``self.quic``/``self.conn`` are set when a connection was
        found or created. They stay None when the datagram was dropped or
        answered statelessly (Version Negotiation or Retry pushed to
        ``channel``).
        """
        ctx = channel.server.ctx

        try:
            header = pull_quic_header(Buffer(data=data),
                                      host_cid_length=ctx.connection_id_length)
        except ValueError:
            # Not a parseable QUIC packet: drop it instead of raising.
            return

        # version negotiation
        if header.version is not None and header.version not in ctx.supported_versions:
            channel.push(
                encode_quic_version_negotiation(
                    source_cid=header.destination_cid,
                    destination_cid=header.source_cid,
                    supported_versions=ctx.supported_versions,
                ))
            return

        conn = self.conns.get(header.destination_cid)
        if conn is not None:
            # Known connection showing up on another channel: move it here.
            # _linked_channel may be None (cleared in process_quic_events when
            # one of the connection's IDs is retired while others remain).
            if conn._linked_channel is not None:
                conn._linked_channel.close()
            conn._linked_channel = channel
            self.quic = conn._quic
            self.conn = conn
            return

        # Only a full-size (>= 1200 bytes) Initial packet may open a connection.
        if header.packet_type != PACKET_TYPE_INITIAL or len(data) < 1200:
            return

        original_connection_id = None
        if self._retry is not None:
            if not header.token:
                # create a retry token
                channel.push(
                    encode_quic_retry(
                        version=header.version,
                        source_cid=os.urandom(8),
                        destination_cid=header.source_cid,
                        original_destination_cid=header.destination_cid,
                        retry_token=self._retry.create_token(
                            channel.addr, header.destination_cid)))
                return
            try:
                original_connection_id = self._retry.validate_token(
                    channel.addr, header.token)
            except ValueError:
                # Invalid or foreign token: drop the packet.
                return

        self.quic = QuicConnection(
            configuration=ctx,
            logger_connection_id=original_connection_id
            or header.destination_cid,
            original_connection_id=original_connection_id,
            session_ticket_fetcher=channel.server.ticket_store.pop,
            session_ticket_handler=channel.server.ticket_store.add)
        self.conn = H3Connection(self.quic)
        self.conn._linked_channel = channel

    def _handle_events(self, events):
        """Dispatch HTTP/3 events to requests, then flush outgoing data."""
        for event in events:
            if isinstance(event, HeadersReceived):
                created = self.handle_request(
                    event.stream_id,
                    event.headers,
                    has_data_frame=not event.stream_ended,
                    push_id=event.push_id)
                if created:
                    r = self.get_request(event.stream_id)
                    if event.stream_ended:
                        r.set_stream_ended()

            elif isinstance(event, DataReceived) and event.data:
                r = self.get_request(event.stream_id)
                if not r:
                    self.close(self.errno.INTERNAL_ERROR)
                    continue
                try:
                    r.channel.set_data(event.data, len(event.data))
                except ValueError:
                    # The body was rejected; do not go on to complete a
                    # request whose data is incomplete.
                    self.close(self.errno.INTERNAL_ERROR)
                    continue

                if event.stream_ended:
                    if r.collector:
                        r.channel.handle_read()
                        r.channel.found_terminator()
                    r.set_stream_ended()
                    if r.response.is_done():
                        self.remove_request(event.stream_id)

            elif isinstance(event, PushCanceled):
                with self._clock:
                    self._push_map.pop(event.push_id, None)
                self.remove_push_stream(event.push_id)

        self.send_data()

    def process_quic_events(self):
        """Drain the QUIC event queue and update connection-ID routing."""
        while 1:
            with self._plock:
                event = self.quic.next_event()
            if event is None:
                break

            if isinstance(event, events.StreamDataReceived):
                h3_events = self.conn.handle_event(event)
                if h3_events:
                    self._handle_events(h3_events)

            elif isinstance(event, events.ConnectionIdIssued):
                self.conns[event.connection_id] = self.conn

            elif isinstance(event, events.ConnectionIdRetired):
                # The ID may never have been registered here (only issued IDs
                # are added above), so tolerate its absence.
                conn = self.conns.pop(event.connection_id, None)
                if conn is not None:
                    conn._linked_channel = None

            elif isinstance(event, events.ConnectionTerminated):
                for cid, conn in list(self.conns.items()):
                    if conn == self.conn:
                        conn._linked_channel = None
                        del self.conns[cid]
                self._terminate_connection()

            elif isinstance(event, events.HandshakeCompleted):
                pass

            elif isinstance(event, events.PingAcknowledged):
                # for now nothing to do, channel will be extended by event time
                pass

    def collect_incoming_data(self, data):
        """Feed one datagram from ``self.channel`` into the QUIC connection."""
        if self.quic is None:
            self._make_connection(self.channel, data)
            if self.quic is None:
                return
        with self._plock:
            self.quic.receive_datagram(data, self.channel.addr,
                                       time.monotonic())
        self.process_quic_events()
        self.send_data()

    def reset_stream(self, stream_id):
        raise AttributeError(
            'HTTP/3 can cancel for only push stream, use cancel_push (push_id)'
        )

    def remove_stream(self, stream_id):
        raise AttributeError('use remove_push_stream (push_id)')

    def remove_push_stream(self, push_id):
        # received by client and just reset
        with self._clock:
            try:
                stream_id = self._push_map.pop(push_id)
            except KeyError:
                return
        super().remove_stream(stream_id)

    def cancel_push(self, push_id):
        # send to client and reset
        with self._clock:
            try:
                stream_id = self._push_map.pop(push_id)
            except KeyError:
                return  # already done or canceled
        with self._plock:
            self.conn.cancel_push(push_id)
        super().remove_stream(stream_id)
        self.send_data()

    def data_to_send(self):
        """Return the QUIC datagrams ready to send (payloads only)."""
        if self.quic is None or self._closed:
            return []
        self._data_from_producers()
        with self._plock:
            data_to_send = [
                data for data, _addr in self.quic.datagrams_to_send(
                    now=time.monotonic())
            ]
        return data_to_send or self._data_exhausted()

    def pushable(self):
        return self.request_acceptable()

    def _maybe_duplicated(self, stream_id, headers):
        """Send a duplicate push when a GET targets a path that was already pushed.

        Returns True when ``self.conn.send_duplicate_push`` was called (the
        request needs no further handling), False otherwise.
        """
        path = None
        for k, v in headers:
            # Pseudo-headers (":name") come first; stop at the first regular one.
            if k[:1] != b':':
                break
            elif k == b':path':
                path = v.decode()
            elif k == b':method' and v != b"GET":
                return False
        if not path:
            # Malformed request (no :path). Let the regular handler deal with
            # it instead of crashing on an assert.
            return False

        with self._clock:
            push_id = self._pushed_pathes.get(path)
            if push_id is None or push_id not in self._push_map:
                return False

        with self._plock:
            self.conn.send_duplicate_push(stream_id, push_id)
        return True

    def handle_request(self,
                       stream_id,
                       headers,
                       has_data_frame=False,
                       push_id=None):
        """Start a request, unless it is answered by a duplicate push.

        Requests created from our own push promises (``push_id`` given) skip
        the duplicate check. Returns False when a duplicate push was sent.
        """
        if push_id is None:
            with self._clock:
                pushing = len(self._pushed_pathes)
            if pushing and self._maybe_duplicated(stream_id, headers):
                return False
        return super().handle_request(stream_id, headers, has_data_frame)

    def push_promise(self, stream_id, request_headers,
                     addtional_request_headers):
        """Send a PUSH_PROMISE on ``stream_id`` and serve the promised request locally.

        ``request_headers`` must start with the ``:path`` pseudo-header.
        Silently does nothing when no push ID is available.
        """
        name_, path = request_headers[0]
        assert name_ == ':path', ':path header missing'
        headers = [(k.encode(), v.encode())
                   for k, v in request_headers + addtional_request_headers]
        try:
            promise_stream_id = self.conn.send_push_promise(
                stream_id=stream_id, headers=headers)
            try:
                push_id = self.conn.get_push_id(promise_stream_id)
            except AttributeError:
                # H3Connection without get_push_id: the push ID just used is
                # one below the next one to be allocated.
                push_id = self.conn._next_push_id - 1
        except NoAvailablePushIDError:
            return
        with self._clock:
            self._push_map[push_id] = promise_stream_id
            self._pushed_pathes[path] = push_id
        self._handle_events([
            HeadersReceived(headers=headers,
                            stream_ended=True,
                            stream_id=promise_stream_id,
                            push_id=push_id)
        ])
Exemplo n.º 14
0
class Connection:
    """Blocking HTTP/3 client over an IPv4 UDP socket, built on aioquic.

    Each request (``get``/``post``/``request``) sends its frames and then
    reads datagrams until the socket times out. ``request.response`` is
    filled in from the HTTP/3 events received meanwhile.
    """
    # NOTE(security): load_session_ticket() unpickles this file. With a fixed
    # path in the shared /tmp directory, any local user able to create it can
    # get arbitrary code executed in this process; prefer a per-user path.
    session_ticket = '/tmp/http3-session-ticket.pik'
    # Seconds recv() waits for a datagram before considering the exchange idle.
    socket_timeout = 1

    def __init__ (self, addr, enable_push = True):
        """Prepare configuration and socket for ``addr`` (``"host[:port]"``, port 443 by default)."""
        # prepare configuration
        self.netloc = addr
        try:
            host, port = addr.split (":", 1)
            port = int (port)
        except ValueError:
            host, port = addr, 443

        self.addr = (host, port)
        self.configuration = QuicConfiguration(is_client = True, alpn_protocols = H3_ALPN)
        self.configuration.load_verify_locations(os.path.join (os.path.dirname (__file__), 'pycacert.pem'))
        # NOTE(security): CERT_NONE disables server certificate verification,
        # so the CA loaded above is not enforced.
        self.configuration.verify_mode = ssl.CERT_NONE
        self.load_session_ticket ()
        self.socket = socket.socket (socket.AF_INET, socket.SOCK_DGRAM)
        self.socket.settimeout (self.socket_timeout)
        self._connected = False
        self._closed = False
        self._history = []
        self._allow_push = enable_push
        self._response = None

    def connect (self):
        """Create a new QUIC + HTTP/3 connection to ``self.addr`` and send the first flight."""
        host, _port = self.addr
        try:
            ipaddress.ip_address(host)
            server_name = None
        except ValueError:
            server_name = host
        if server_name is not None:
            self.configuration.server_name = server_name
        self._quic = QuicConnection (
            configuration = self.configuration, session_ticket_handler = self.save_session_ticket
        )
        self._http = H3Connection(self._quic)
        self._quic.connect(self.addr, now=time.monotonic ())
        self.transmit ()

    def close (self):
        """Flush pending output, drain replies and refuse further requests.

        The socket is intentionally left open (closing it is commented out).
        """
        self.transmit ()
        self.recv ()
        #self.socket.close ()
        self._closed = True

    def transmit (self):
        """Send all pending QUIC datagrams, then read replies via recv()."""
        dts = self._quic.datagrams_to_send(now=time.monotonic ())
        if not dts:
            return

        for data, _addr in dts:
            # Always sent to the configured peer address.
            self.socket.sendto (data, self.addr)
        self.recv ()

    def recv (self):
        """Read datagrams until the socket times out, then transmit any replies."""
        while 1:
            try:
                data, addr = self.socket.recvfrom (4096)
            except socket.timeout:
                break
            self._quic.receive_datagram(data, addr, now = time.monotonic ())
            self._process_events()
        self.transmit ()

    def save_session_ticket (self, ticket):
        with open(self.session_ticket, "wb") as fp:
            pickle.dump(ticket, fp)

    def load_session_ticket (self):
        # NOTE(security): pickle.load on a file in /tmp; see ``session_ticket``.
        try:
            with open(self.session_ticket, "rb") as fp:
                self.configuration.session_ticket = pickle.load(fp)
        except FileNotFoundError:
            pass

    def _process_events(self):
        """Drain QUIC events, tracking connection state and feeding the HTTP layer."""
        event = self._quic.next_event()
        while event is not None:
            if isinstance(event, events.ConnectionTerminated):
                self._connected = False
                self.close ()
            elif isinstance(event, events.HandshakeCompleted):
                self._connected = True
            elif isinstance(event, events.PingAcknowledged):
                pass
            self.quic_event_received(event)
            event = self._quic.next_event()

    def http_event_received (self, http_event):
        """Record an HTTP/3 event on the current response, if a request is in flight."""
        if self._response is None:
            # e.g. events arriving during connect() (called before _response
            # is set) or after the request finished.
            return
        if isinstance(http_event, HeadersReceived):
            # NOTE(review): not filtered by stream id, unlike DataReceived below.
            self._response.headers = {k: v for k, v in http_event.headers}
        elif isinstance(http_event, DataReceived):
            # Only client-initiated bidirectional streams (id % 4 == 0) carry
            # the response body.
            if http_event.stream_id % 4 == 0:
                self._response.data += http_event.data
        elif isinstance(http_event, PushPromiseReceived):
            push_headers = {}
            for k, v in http_event.headers:
                push_headers [k.decode ()] = v.decode ()
            self._response.promises.append (push_headers)

    def quic_event_received (self, event):
        """Log ``event`` on the current response and pass it to the HTTP/3 layer."""
        if self._response and not isinstance (event, events.StreamDataReceived):
            self._response.events.append (event)

        for http_event in self._http.handle_event(event):
            if self._response is not None and not isinstance (
                    http_event, (HeadersReceived, PushPromiseReceived, DataReceived)):
                # logging only control frames
                self._response.events.append (http_event)
            self.http_event_received(http_event)

    def handle_request (self, request):
        """Send ``request`` (headers and body), wait for replies and return ``request.response``.

        Raises:
            ConnectionClosed: the connection was closed.
        """
        if self._closed:
            raise ConnectionClosed

        if not self._connected:
            self.connect ()

        self._response = request.response
        stream_id = self._quic.get_next_available_stream_id()
        self._response.stream_id = stream_id
        self._http.send_headers(
            stream_id=stream_id,
            headers = [
                (b":method", request.method.encode("utf8")),
                (b":scheme", request.url.scheme.encode("utf8")),
                (b":authority", request.url.authority.encode("utf8")),
                (b":path", request.url.full_path.encode("utf8")),
                (b"user-agent", b"aioquic"),
            ]
            + [
                (k.encode("utf8"), v.encode("utf8"))
                for (k, v) in request.headers.items()
            ],
        )
        self._http.send_data (stream_id=stream_id, data=request.content, end_stream=True)
        self.transmit()

        self._response = None
        return request.response

    def get (self, url, headers = None):
        """Send a GET for ``url`` (a path on this connection's host)."""
        req = HttpRequest ('GET', 'https://{}{}'.format (self.netloc, url), b'',
                           {} if headers is None else headers)
        return self.handle_request (req)

    def post (self, url, data, headers = None):
        """Send a POST with body ``data`` for ``url``."""
        req = HttpRequest ('POST', 'https://{}{}'.format (self.netloc, url), data,
                           {} if headers is None else headers)
        return self.handle_request (req)

    def request (self, method, url, data = b'', headers = None):
        """Send an arbitrary ``method`` request for ``url``."""
        req = HttpRequest (method, 'https://{}{}'.format (self.netloc, url), data,
                           {} if headers is None else headers)
        return self.handle_request (req)