Example #1
0
    async def handle(self, event: Event) -> None:
        """Route a transport event to the matching QUIC connection.

        For a ``RawData`` datagram:

        * datagrams whose QUIC header cannot be parsed are dropped;
        * a datagram for a version not in ``quic_config.supported_versions``
          is answered with a Version Negotiation packet;
        * an Initial packet of at least 1200 bytes whose destination CID is
          unknown creates a new ``QuicConnection``, registered under both the
          client-chosen destination CID and the connection's own host CID;
        * the datagram is then fed to the connection (if any) and the
          connection's resulting events are handled.

        ``Closed`` events are currently ignored.
        """
        if isinstance(event, RawData):
            # Short-header packets carry no CID length field, so the parser
            # must be told the length of the host CIDs this server uses.
            # Read it from the configuration instead of hard-coding 8; fall
            # back to 8 if the configuration object does not define it.
            cid_length = getattr(self.quic_config, "connection_id_length", 8)
            try:
                header = pull_quic_header(
                    Buffer(data=event.data), host_cid_length=cid_length
                )
            except ValueError:
                # Not a parsable QUIC packet: drop it silently.
                return
            if (
                header.version is not None
                and header.version not in self.quic_config.supported_versions
            ):
                data = encode_quic_version_negotiation(
                    source_cid=header.destination_cid,
                    destination_cid=header.source_cid,
                    supported_versions=self.quic_config.supported_versions,
                )
                await self.send(RawData(data=data, address=event.address))
                return

            connection = self.connections.get(header.destination_cid)
            if (
                connection is None
                # 1200 bytes: minimum size of a datagram carrying a client
                # Initial packet (RFC 9000, section 14.1).
                and len(event.data) >= 1200
                and header.packet_type == PACKET_TYPE_INITIAL
            ):
                connection = QuicConnection(
                    configuration=self.quic_config, original_connection_id=None
                )
                # Register under both the client-chosen DCID and our own CID
                # so later packets addressed to either are routed here.
                self.connections[header.destination_cid] = connection
                self.connections[connection.host_cid] = connection

            if connection is not None:
                connection.receive_datagram(event.data, event.address, now=self.now())
                await self._handle_events(connection, event.address)
        elif isinstance(event, Closed):
            pass
Example #2
0
class http3_request_handler(http2_handler.http2_request_handler):
    """HTTP/3 request handler built on the HTTP/2 handler machinery.

    ``self.quic`` (QUIC transport) and ``self.conn`` (HTTP/3 layer) are
    created lazily from the first datagram seen on the channel (see
    ``_make_connection``).  Later datagrams are fed to ``self.quic``.  The
    resulting events are dispatched to the parent class's request handling.
    Server push streams are tracked in ``self._push_map`` (push_id ->
    stream_id).
    """
    producer_class = http3_producer
    # Answer token-less Initial packets with a Retry (address validation)
    # before allocating any connection state.
    stateless_retry = True
    # Class attribute, hence shared by every handler instance:
    # connection ID -> H3Connection.  _make_connection relies on this to
    # attach a datagram arriving on a new channel to a connection that was
    # created by another handler.
    conns = {}
    errno = ErrorCode

    def __init__(self, handler, request):
        # _default_varialbes is presumably provided by the parent class
        # (name kept as-is; it is not defined here).
        self._default_varialbes(handler, request)
        self.quic = None  # QUIC protocol
        self.conn = None  # HTTP3 Protocol

        self._push_map = {}  # push_id -> push stream_id
        self._retry = QuicRetryTokenHandler() if self.stateless_retry else None

    def _initiate_shutdown(self):
        """Shutdown phase I: send GOAWAY on the HTTP/3 connection, if any."""
        if self.conn is None:
            # No datagram has created a connection yet; nothing to notify
            # (mirrors the ``self.quic`` guard in _proceed_shutdown).
            return
        with self._plock:
            self.conn.shutdown(self._shutdown_reason[-1])

    def _proceed_shutdown(self):
        """Shutdown phase II: close the QUIC connection and flush it."""
        if self.quic:
            errcode, msg, _ = self._shutdown_reason
            with self._plock:
                self.quic.close(errcode, reason_phrase=msg or '')
            self.send_data()

    def _make_connection(self, channel, data):
        """Set up ``self.quic``/``self.conn`` from the first datagram.

        On return ``self.quic`` is still ``None`` when the datagram was
        answered statelessly (version negotiation, retry) or dropped
        (unparsable header, non-Initial / undersized packet, invalid token).
        Otherwise it refers to either an existing connection found in
        ``conns`` (re-linked to *channel*) or a freshly created one.
        """
        ctx = channel.server.ctx

        buf = Buffer(data=data)
        try:
            header = pull_quic_header(buf,
                                      host_cid_length=ctx.connection_id_length)
        except ValueError:
            # Malformed datagram from the network: drop it instead of letting
            # the parse error escape into the channel's read path.
            return

        # version negotiation
        if header.version is not None and header.version not in ctx.supported_versions:
            channel.push(
                encode_quic_version_negotiation(
                    source_cid=header.destination_cid,
                    destination_cid=header.source_cid,
                    supported_versions=ctx.supported_versions,
                ))
            return

        # Datagram for a connection we already know (e.g. arriving over a
        # different channel): move the connection onto this channel.
        conn = self.conns.get(header.destination_cid)
        if conn is not None:
            old_channel = getattr(conn, '_linked_channel', None)
            # _linked_channel may already be None (process_quic_events clears
            # it when one of the connection's CIDs is retired while other CIDs
            # still map to the same connection).
            if old_channel is not None and old_channel is not channel:
                old_channel.close()
            conn._linked_channel = channel
            self.quic = conn._quic
            self.conn = conn
            return

        # 1200 bytes: minimum size of a client Initial datagram (RFC 9000).
        if header.packet_type != PACKET_TYPE_INITIAL or len(data) < 1200:
            return

        original_connection_id = None
        if self._retry is not None:
            if not header.token:
                # create a retry token
                channel.push(
                    encode_quic_retry(
                        version=header.version,
                        source_cid=os.urandom(8),
                        destination_cid=header.source_cid,
                        original_destination_cid=header.destination_cid,
                        retry_token=self._retry.create_token(
                            channel.addr, header.destination_cid)))
                return
            else:
                try:
                    original_connection_id = self._retry.validate_token(
                        channel.addr, header.token)
                except ValueError:
                    return

        self.quic = QuicConnection(
            configuration=ctx,
            logger_connection_id=original_connection_id
            or header.destination_cid,
            original_connection_id=original_connection_id,
            session_ticket_fetcher=channel.server.ticket_store.pop,
            session_ticket_handler=channel.server.ticket_store.add)
        self.conn = H3Connection(self.quic)
        self.conn._linked_channel = channel

    def _handle_events(self, events):
        """Feed HTTP/3 events into the request machinery, then flush output.

        ``events`` is a list of H3 events (HeadersReceived, DataReceived,
        PushCanceled); other event types are ignored.  Inside this method the
        parameter shadows the module-level ``events`` namespace.
        """
        for event in events:
            if isinstance(event, HeadersReceived):
                created = self.handle_request(
                    event.stream_id,
                    event.headers,
                    has_data_frame=not event.stream_ended,
                    push_id=event.push_id)
                if created:
                    r = self.get_request(event.stream_id)
                    if event.stream_ended:
                        r.set_stream_ended()

            elif isinstance(event, DataReceived) and event.data:
                r = self.get_request(event.stream_id)
                if not r:
                    # Body data for a stream with no request object.
                    self.close(self.errno.INTERNAL_ERROR)
                else:
                    try:
                        r.channel.set_data(event.data, len(event.data))
                    except ValueError:
                        # NOTE(review): processing continues below even after
                        # close(); confirm that is intended.
                        self.close(self.errno.INTERNAL_ERROR)

                    if event.stream_ended:
                        if r.collector:
                            r.channel.handle_read()
                            r.channel.found_terminator()
                        r.set_stream_ended()
                        if r.response.is_done():
                            self.remove_request(event.stream_id)

            elif isinstance(event, PushCanceled):
                # Client cancelled a push: forget it and drop its stream.
                with self._clock:
                    try:
                        del self._push_map[event.push_id]
                    except KeyError:
                        pass
                self.remove_push_stream(event.push_id)

        self.send_data()

    def process_quic_events(self):
        """Drain ``self.quic``'s event queue, maintaining ``conns``."""
        while 1:
            with self._plock:
                event = self.quic.next_event()
            if event is None:
                break

            if isinstance(event, events.StreamDataReceived):
                h3_events = self.conn.handle_event(event)
                h3_events and self._handle_events(h3_events)

            elif isinstance(event, events.ConnectionIdIssued):
                self.conns[event.connection_id] = self.conn

            elif isinstance(event, events.ConnectionIdRetired):
                # Forget the CID only if it maps to this handler's connection;
                # an unknown or foreign CID is ignored rather than raising
                # (the former assert vanished under ``python -O``).
                if self.conns.get(event.connection_id) is self.conn:
                    conn = self.conns.pop(event.connection_id)
                    conn._linked_channel = None

            elif isinstance(event, events.ConnectionTerminated):
                # Drop every CID that still points at this connection.
                for cid, conn in list(self.conns.items()):
                    if conn is self.conn:
                        conn._linked_channel = None
                        del self.conns[cid]
                self._terminate_connection()

            elif isinstance(event, events.HandshakeCompleted):
                pass

            elif isinstance(event, events.PingAcknowledged):
                # for now nothing to do, channel will be extened by event time
                pass

    def collect_incoming_data(self, data):
        """Handle one UDP datagram received on ``self.channel``."""
        if self.quic is None:
            self._make_connection(self.channel, data)
            if self.quic is None:
                # Datagram was answered statelessly or dropped.
                return
        with self._plock:
            self.quic.receive_datagram(data, self.channel.addr,
                                       time.monotonic())
        self.process_quic_events()
        self.send_data()

    def reset_stream(self, stream_id):
        raise AttributeError(
            'HTTP/3 can cancel for only push stream, use cancel_push (push_id)'
        )

    def remove_stream(self, stream_id):
        raise AttributeError('use remove_push_stream (push_id)')

    def remove_push_stream(self, push_id):
        """Drop the local stream of a push the client cancelled."""
        # received by client and just reset
        with self._clock:
            try:
                stream_id = self._push_map.pop(push_id)
            except KeyError:
                return
        super().remove_stream(stream_id)

    def cancel_push(self, push_id):
        """Cancel a server push towards the client and drop its stream."""
        # send to client and reset
        with self._clock:
            try:
                stream_id = self._push_map.pop(push_id)
            except KeyError:
                return  # already done or canceled
        with self._plock:
            self.conn.cancel_push(push_id)
        super().remove_stream(stream_id)
        self.send_data()

    def data_to_send(self):
        """Return the datagram payloads QUIC wants sent now."""
        if self.quic is None or self._closed:
            return []
        self._data_from_producers()
        with self._plock:
            data_to_send = [
                data for data, addr in self.quic.datagrams_to_send(
                    now=time.monotonic())
            ]
        return data_to_send or self._data_exhausted()

    def pushable(self):
        return self.request_acceptable()

    def _maybe_duplicated(self, stream_id, headers):
        """Answer a GET for an already-pushed path with DUPLICATE_PUSH.

        Returns True when a duplicate push was sent (the request needs no
        further handling), False otherwise.
        """
        path = None
        for k, v in headers:
            # 58 == ord(':'); pseudo-headers come first, stop at the first
            # regular header.
            if k[0] != 58: break
            elif k == b':path':
                path = v.decode()
            elif k == b':method' and v != b"GET":
                return False
        assert path, ':path is missing'

        with self._clock:
            push_id = self._pushed_pathes.get(path)
            if push_id is None or push_id not in self._push_map:
                return False

        with self._plock:
            self.conn.send_duplicate_push(stream_id, push_id)
        return True

    def handle_request(self,
                       stream_id,
                       headers,
                       has_data_frame=False,
                       push_id=None):
        """Start a request, short-circuiting duplicates of pushed paths."""
        if push_id is None:
            with self._clock:
                pushing = len(self._pushed_pathes)
            if pushing and self._maybe_duplicated(stream_id, headers):
                return False
        return super().handle_request(stream_id, headers, has_data_frame)

    def push_promise(self, stream_id, request_headers,
                     addtional_request_headers):
        """Promise a push on *stream_id* and serve it as a local request.

        ``request_headers`` must start with the ``:path`` header; str
        header pairs are encoded to bytes before being sent.  Returns None
        without pushing when no push ID is available.
        """
        name_, path = request_headers[0]
        assert name_ == ':path', ':path header missing'
        headers = [(k.encode(), v.encode())
                   for k, v in request_headers + addtional_request_headers]
        try:
            promise_stream_id = self.conn.send_push_promise(
                stream_id=stream_id, headers=headers)
            try:
                push_id = self.conn.get_push_id(promise_stream_id)
            except AttributeError:
                # Fallback for H3Connection versions without get_push_id:
                # the push ID just allocated is one below the next free one.
                push_id = self.conn._next_push_id - 1
        except NoAvailablePushIDError:
            return
        with self._clock:
            self._push_map[push_id] = promise_stream_id
            self._pushed_pathes[path] = push_id
        # Feed the promised request through the normal request path.
        self._handle_events([
            HeadersReceived(headers=headers,
                            stream_ended=True,
                            stream_id=promise_stream_id,
                            push_id=push_id)
        ])
Example #3
0
class Connection:
    """Minimal blocking HTTP/3 client connection over a UDP socket.

    ``handle_request`` sends one request and then pumps datagrams
    (``transmit``/``recv``) until a socket read times out.  The response
    headers, body and push promises are collected into ``request.response``.
    """
    # NOTE(review): the session ticket is pickled to / unpickled from a fixed
    # path in the shared /tmp directory.  pickle.load on a file that another
    # local user can write allows arbitrary code execution -- use a private
    # location for anything beyond local testing.
    session_ticket = '/tmp/http3-session-ticket.pik'
    # Seconds recvfrom() waits before recv() stops reading.
    socket_timeout = 1

    def __init__ (self, addr, enable_push = True):
        """Prepare (but do not open) a connection to ``"host[:port]"``.

        The port defaults to 443 when missing or not an integer.
        """
        # prepare configuration
        self.netloc = addr
        try:
            host, port = addr.split (":", 1)
            port = int (port)
        except ValueError:
            host, port = addr, 443

        self.addr = (host, port)
        self.configuration = QuicConfiguration(is_client = True, alpn_protocols = H3_ALPN)
        self.configuration.load_verify_locations(os.path.join (os.path.dirname (__file__), 'pycacert.pem'))
        # NOTE(review): certificate verification is disabled even though a CA
        # file was just loaded -- acceptable only against a test server.
        self.configuration.verify_mode = ssl.CERT_NONE
        self.load_session_ticket ()
        self.socket = socket.socket (socket.AF_INET, socket.SOCK_DGRAM)
        self.socket.settimeout (self.socket_timeout)
        self._connected = False
        self._closed = False
        self._history = []
        self._allow_push = enable_push
        self._response = None

    def connect (self):
        """Create the QUIC/HTTP3 objects and send the first flight."""
        host, port = self.addr
        # Only send SNI when the host is a name, not an IP literal.
        try:
            ipaddress.ip_address(host)
            server_name = None
        except ValueError:
            server_name = host
        if server_name is not None:
            self.configuration.server_name = server_name
        self._quic = QuicConnection (
            configuration = self.configuration, session_ticket_handler = self.save_session_ticket
        )
        self._http = H3Connection(self._quic)
        self._quic.connect(self.addr, now=time.monotonic ())
        self.transmit ()

    def close (self):
        """Flush pending traffic and mark the connection closed."""
        self.transmit ()
        self.recv ()
        self._closed = True

    def transmit (self):
        """Send every pending datagram, then read the replies.

        ``transmit`` and ``recv`` call each other until QUIC has nothing
        left to send.
        """
        dts = self._quic.datagrams_to_send(now=time.monotonic ())
        if not dts:
            return

        for data, _addr in dts:
            self.socket.sendto (data, self.addr)
        self.recv ()

    def recv (self):
        """Read datagrams until the socket times out, then send replies."""
        while 1:
            try:
                data, addr = self.socket.recvfrom (4096)
            except socket.timeout:
                break
            self._quic.receive_datagram(data, addr, now = time.monotonic ())
            self._process_events()
        self.transmit ()

    def save_session_ticket (self, ticket):
        """Persist a TLS session ticket for 0-RTT resumption."""
        with open(self.session_ticket, "wb") as fp:
            pickle.dump(ticket, fp)

    def load_session_ticket (self):
        """Load a previously saved session ticket, if any.

        A missing, empty or corrupt ticket file is ignored; the connection
        then simply starts without resumption.
        """
        try:
            with open(self.session_ticket, "rb") as fp:
                self.configuration.session_ticket = pickle.load(fp)
        except (FileNotFoundError, EOFError, pickle.UnpicklingError):
            pass

    def _process_events(self):
        """Drain QUIC events, tracking connection state."""
        event = self._quic.next_event()
        while event is not None:
            if isinstance(event, events.ConnectionTerminated):
                self._connected = False
                self.close ()
            elif isinstance(event, events.HandshakeCompleted):
                self._connected = True
            elif isinstance(event, events.PingAcknowledged):
                pass
            self.quic_event_received(event)
            event = self._quic.next_event()

    def http_event_received (self, http_event):
        """Record an HTTP/3 event on the in-flight response.

        Events arriving while no request is in flight (``_response`` is
        None, e.g. during ``close()``) are ignored instead of raising.
        """
        if self._response is None:
            return
        if isinstance(http_event, HeadersReceived):
            self._response.headers = {k: v for k, v in http_event.headers}
        elif isinstance(http_event, DataReceived):
            # Keep body data from client-initiated bidirectional streams only
            # (stream ID % 4 == 0); other streams' data is not collected.
            if http_event.stream_id % 4 == 0:
                self._response.data += http_event.data
        elif isinstance(http_event, PushPromiseReceived):
            push_headers = {}
            for k, v in http_event.headers:
                push_headers [k.decode ()] = v.decode ()
            self._response.promises.append (push_headers)

    def quic_event_received (self, event):
        """Log a QUIC event and pass it to the HTTP/3 layer."""
        #  pass event to the HTTP layer
        if self._response and not isinstance (event, events.StreamDataReceived):
            self._response.events.append (event)

        for http_event in self._http.handle_event(event):
            if self._response is not None and not isinstance (http_event, (HeadersReceived, PushPromiseReceived, DataReceived)):
                # logging only control frames
                self._response.events.append (http_event)
            self.http_event_received(http_event)

    def handle_request (self, request):
        """Send *request* on a new stream and return its response.

        Raises ConnectionClosed if ``close()`` was already called; connects
        first if the handshake has not completed.
        """
        if self._closed:
            raise ConnectionClosed

        if not self._connected:
            self.connect ()

        self._response = request.response
        stream_id = self._quic.get_next_available_stream_id()
        self._response.stream_id = stream_id
        self._http.send_headers(
            stream_id=stream_id,
            headers = [
                (b":method", request.method.encode("utf8")),
                (b":scheme", request.url.scheme.encode("utf8")),
                (b":authority", request.url.authority.encode("utf8")),
                (b":path", request.url.full_path.encode("utf8")),
                (b"user-agent", b"aioquic"),
            ]
            + [
                (k.encode("utf8"), v.encode("utf8"))
                for (k, v) in request.headers.items()
            ],
        )
        self._http.send_data (stream_id=stream_id, data=request.content, end_stream=True)
        # Pumps datagrams until a read times out, filling request.response.
        self.transmit()

        self._response = None
        return request.response

    def get (self, url, headers = None):
        """GET *url* (a path on this connection's host)."""
        if headers is None:
            headers = {}
        req = HttpRequest ('GET', 'https://{}{}'.format (self.netloc, url), b'', headers)
        return self.handle_request (req)

    def post (self, url, data, headers = None):
        """POST *data* to *url* (a path on this connection's host)."""
        if headers is None:
            headers = {}
        req = HttpRequest ('POST', 'https://{}{}'.format (self.netloc, url), data, headers)
        return self.handle_request (req)

    def request (self, method, url, data = b'', headers = None):
        """Send an arbitrary *method* request to *url*."""
        if headers is None:
            headers = {}
        req = HttpRequest (method, 'https://{}{}'.format (self.netloc, url), data, headers)
        return self.handle_request (req)