async def handle(self, event: Event) -> None:
    """Route one server event to the QUIC connection it belongs to.

    Only ``RawData`` datagrams are processed; ``Closed`` is ignored.

    For a datagram this:
      * drops it when the QUIC header cannot be parsed;
      * answers with a version-negotiation packet when the long-header
        version is not in ``quic_config.supported_versions``;
      * creates a new ``QuicConnection`` for an unknown connection ID when
        the packet is an INITIAL of at least 1200 bytes, registering it
        under both the client-chosen destination CID and its own host CID;
      * otherwise feeds the datagram to the existing connection (if any) and
        processes the resulting events.
    """
    if isinstance(event, RawData):
        try:
            # Short headers carry no CID length on the wire, so it must match
            # the length our QuicConnections use for their host CIDs (taken
            # from the same configuration) instead of a hard-coded 8.
            header = pull_quic_header(
                Buffer(data=event.data),
                host_cid_length=self.quic_config.connection_id_length,
            )
        except ValueError:
            return  # not a parseable QUIC packet: drop it

        if (
            header.version is not None
            and header.version not in self.quic_config.supported_versions
        ):
            data = encode_quic_version_negotiation(
                source_cid=header.destination_cid,
                destination_cid=header.source_cid,
                supported_versions=self.quic_config.supported_versions,
            )
            await self.send(RawData(data=data, address=event.address))
            return

        connection = self.connections.get(header.destination_cid)
        if (
            connection is None
            and len(event.data) >= 1200
            and header.packet_type == PACKET_TYPE_INITIAL
        ):
            connection = QuicConnection(
                configuration=self.quic_config, original_connection_id=None
            )
            self.connections[header.destination_cid] = connection
            self.connections[connection.host_cid] = connection

        if connection is not None:
            connection.receive_datagram(event.data, event.address, now=self.now())
            await self._handle_events(connection, event.address)
    elif isinstance(event, Closed):
        pass
class http3_request_handler(http2_handler.http2_request_handler):
    """HTTP/3 request handler on top of aioquic.

    Request/stream bookkeeping is inherited from
    ``http2_handler.http2_request_handler``; this class adds QUIC connection
    setup (version negotiation, optional stateless retry), connection-ID
    routing for client migration, and HTTP/3 server-push tracking.
    """
    producer_class = http3_producer
    stateless_retry = True
    # Class attribute on purpose: shared by every handler instance. Maps a
    # connection ID to its H3Connection so a packet that arrives on a new
    # channel (client changed address) is re-attached to its connection.
    conns = {}
    errno = ErrorCode

    def __init__(self, handler, request):
        # initializer spelled this way upstream (presumably defined by the
        # parent http2 handler) - do not rename here
        self._default_varialbes(handler, request)
        self.quic = None  # QUIC protocol (QuicConnection)
        self.conn = None  # HTTP/3 protocol (H3Connection)
        self._push_map = {}  # push_id -> promised stream_id
        self._retry = QuicRetryTokenHandler() if self.stateless_retry else None

    def _initiate_shutdown(self):
        """Shutdown phase I: send GOAWAY through the H3 connection."""
        with self._plock:
            self.conn.shutdown(self._shutdown_reason[-1])

    def _proceed_shutdown(self):
        """Shutdown phase II: close the QUIC connection and flush it out."""
        if self.quic:
            errcode, msg, _ = self._shutdown_reason
            with self._plock:
                self.quic.close(errcode, reason_phrase=msg or '')
            self.send_data()

    def _make_connection(self, channel, data):
        """Attach this handler to a QUIC connection, given its first datagram.

        Depending on the packet this either drops it, answers with version
        negotiation, re-attaches a known connection (the client migrated to
        *channel*), answers with a stateless retry, or creates a new
        QuicConnection/H3Connection pair. ``self.quic`` stays None whenever
        no connection was attached.
        """
        ctx = channel.server.ctx
        try:
            header = pull_quic_header(
                Buffer(data=data), host_cid_length=ctx.connection_id_length)
        except ValueError:
            return  # not a parseable QUIC packet: drop it

        # version negotiation
        if header.version is not None and header.version not in ctx.supported_versions:
            channel.push(
                encode_quic_version_negotiation(
                    source_cid=header.destination_cid,
                    destination_cid=header.source_cid,
                    supported_versions=ctx.supported_versions,
                ))
            return

        conn = self.conns.get(header.destination_cid)
        if conn is not None:
            # known connection ID on a new channel: detach the old channel
            # (it may already have been unlinked) and take the connection over
            old_channel = conn._linked_channel
            if old_channel is not None:
                old_channel.close()
            conn._linked_channel = channel
            self.quic = conn._quic
            self.conn = conn
            return

        if header.packet_type != PACKET_TYPE_INITIAL or len(data) < 1200:
            return

        original_connection_id = None
        if self._retry is not None:
            if not header.token:
                # no token yet: answer with a retry carrying a fresh token
                channel.push(
                    encode_quic_retry(
                        version=header.version,
                        source_cid=os.urandom(ctx.connection_id_length),
                        destination_cid=header.source_cid,
                        original_destination_cid=header.destination_cid,
                        retry_token=self._retry.create_token(
                            channel.addr, header.destination_cid)))
                return
            try:
                original_connection_id = self._retry.validate_token(
                    channel.addr, header.token)
            except ValueError:
                return  # invalid token: drop

        self.quic = QuicConnection(
            configuration=ctx,
            logger_connection_id=original_connection_id or header.destination_cid,
            original_connection_id=original_connection_id,
            session_ticket_fetcher=channel.server.ticket_store.pop,
            session_ticket_handler=channel.server.ticket_store.add)
        self.conn = H3Connection(self.quic)
        self.conn._linked_channel = channel

    def _handle_events(self, events):
        """Dispatch HTTP/3 events (from H3Connection) to request handling."""
        for event in events:
            if isinstance(event, HeadersReceived):
                created = self.handle_request(
                    event.stream_id, event.headers,
                    has_data_frame=not event.stream_ended,
                    push_id=event.push_id)
                if created and event.stream_ended:
                    self.get_request(event.stream_id).set_stream_ended()

            elif isinstance(event, DataReceived):
                r = self.get_request(event.stream_id)
                if not r:
                    # A bare end-of-stream for a request that is already gone
                    # is harmless; real data for an unknown stream is not.
                    if event.data:
                        self.close(self.errno.INTERNAL_ERROR)
                    continue
                if event.data:
                    try:
                        r.channel.set_data(event.data, len(event.data))
                    except ValueError:
                        self.close(self.errno.INTERNAL_ERROR)
                        continue
                # A lone FIN arrives as DataReceived with empty data, so
                # stream end must be handled even when there is no payload.
                if event.stream_ended:
                    if r.collector:
                        r.channel.handle_read()
                        r.channel.found_terminator()
                    r.set_stream_ended()
                    if r.response.is_done():
                        self.remove_request(event.stream_id)

            elif isinstance(event, PushCanceled):
                # remove_push_stream pops _push_map itself; deleting the entry
                # first would make it return early and leak the stream
                self.remove_push_stream(event.push_id)

        self.send_data()

    def process_quic_events(self):
        """Drain QUIC events, feeding stream data to the H3 layer."""
        while True:
            with self._plock:
                event = self.quic.next_event()
            if event is None:
                break

            if isinstance(event, events.StreamDataReceived):
                h3_events = self.conn.handle_event(event)
                if h3_events:
                    self._handle_events(h3_events)
            elif isinstance(event, events.ConnectionIdIssued):
                self.conns[event.connection_id] = self.conn
            elif isinstance(event, events.ConnectionIdRetired):
                # Only the mapping goes away - the connection stays alive
                # under its other IDs, so its channel link is kept. IDs that
                # were never registered (e.g. the initial host CID, which
                # _make_connection does not add) are tolerated.
                if self.conns.get(event.connection_id) is self.conn:
                    del self.conns[event.connection_id]
            elif isinstance(event, events.ConnectionTerminated):
                for cid, conn in list(self.conns.items()):
                    if conn is self.conn:
                        conn._linked_channel = None
                        del self.conns[cid]
                self._terminate_connection()
            # HandshakeCompleted / PingAcknowledged: nothing to do for now
            # (the channel's lifetime is extended by event time)

    def collect_incoming_data(self, data):
        """Feed one received datagram into this handler's QUIC connection."""
        if self.quic is None:
            self._make_connection(self.channel, data)
            if self.quic is None:
                return  # dropped, or answered statelessly
        with self._plock:
            self.quic.receive_datagram(data, self.channel.addr, time.monotonic())
        self.process_quic_events()
        self.send_data()

    def reset_stream(self, stream_id):
        raise AttributeError(
            'HTTP/3 can cancel for only push stream, use cancel_push (push_id)'
        )

    def remove_stream(self, stream_id):
        raise AttributeError('use remove_push_stream (push_id)')

    def remove_push_stream(self, push_id):
        """Forget a push the client canceled and remove its stream."""
        with self._clock:
            try:
                stream_id = self._push_map.pop(push_id)
            except KeyError:
                return  # already done or canceled
        super().remove_stream(stream_id)

    def cancel_push(self, push_id):
        """Cancel a push we promised, notify the peer and remove its stream."""
        with self._clock:
            try:
                stream_id = self._push_map.pop(push_id)
            except KeyError:
                return  # already done or canceled
        with self._plock:
            self.conn.cancel_push(push_id)
        super().remove_stream(stream_id)
        self.send_data()

    def data_to_send(self):
        """Return pending QUIC datagrams (payloads only) for the channel."""
        if self.quic is None or self._closed:
            return []
        self._data_from_producers()
        with self._plock:
            datagrams = [
                data for data, _addr in self.quic.datagrams_to_send(
                    now=time.monotonic())
            ]
        return datagrams or self._data_exhausted()

    def pushable(self):
        return self.request_acceptable()

    def _maybe_duplicated(self, stream_id, headers):
        """Answer a GET for an already-pushed path with a duplicate push.

        Returns True when a DUPLICATE_PUSH was sent for *stream_id*.
        """
        path = None
        for k, v in headers:
            if not k.startswith(b':'):
                break  # pseudo-headers come first
            if k == b':path':
                path = v.decode()
            elif k == b':method' and v != b"GET":
                return False
        if not path:
            # malformed request (no :path): not a duplicate; let the normal
            # request path deal with it instead of asserting on client input
            return False

        with self._clock:
            push_id = self._pushed_pathes.get(path)
            if push_id is None or push_id not in self._push_map:
                return False
        with self._plock:
            self.conn.send_duplicate_push(stream_id, push_id)
        return True

    def handle_request(self, stream_id, headers, has_data_frame=False, push_id=None):
        if push_id is None:
            with self._clock:
                pushing = len(self._pushed_pathes)
            # outside the lock: _maybe_duplicated acquires _clock itself
            if pushing and self._maybe_duplicated(stream_id, headers):
                return False
        return super().handle_request(stream_id, headers, has_data_frame)

    def push_promise(self, stream_id, request_headers, addtional_request_headers):
        """Promise a push on *stream_id* and start serving it locally.

        The pushed request is fed back through ``_handle_events`` as a
        synthetic, already-ended HeadersReceived on the promised stream.
        """
        name_, path = request_headers[0]
        assert name_ == ':path', ':path header missing'
        headers = [
            (k.encode(), v.encode())
            for k, v in request_headers + addtional_request_headers
        ]
        try:
            promise_stream_id = self.conn.send_push_promise(
                stream_id=stream_id, headers=headers)
            try:
                push_id = self.conn.get_push_id(promise_stream_id)
            except AttributeError:
                # H3Connection without get_push_id: the push id just used
                push_id = self.conn._next_push_id - 1
        except NoAvailablePushIDError:
            return

        with self._clock:
            self._push_map[push_id] = promise_stream_id
            self._pushed_pathes[path] = push_id

        self._handle_events([
            HeadersReceived(headers=headers, stream_ended=True,
                            stream_id=promise_stream_id, push_id=push_id)
        ])
class Connection:
    """Small blocking HTTP/3 client built on aioquic (one request at a time).

    Datagrams go over a plain UDP socket. After sending, the client keeps
    reading until the socket has been quiet for ``socket_timeout`` seconds.
    """
    # NOTE(review): session tickets are pickled to / unpickled from this fixed
    # path under /tmp. Anyone able to write that file gets code executed via
    # pickle.load(); only use this client in trusted test environments.
    session_ticket = '/tmp/http3-session-ticket.pik'
    socket_timeout = 1

    def __init__(self, addr, enable_push=True):
        """Prepare the QUIC client configuration and the UDP socket.

        :param addr: ``"host"`` or ``"host:port"`` (port defaults to 443).
        :param enable_push: stored as ``_allow_push``.
        """
        self.netloc = addr
        try:
            host, port = addr.split(":", 1)
            port = int(port)
        except ValueError:
            host, port = addr, 443
        self.addr = (host, port)

        self.configuration = QuicConfiguration(is_client=True, alpn_protocols=H3_ALPN)
        self.configuration.load_verify_locations(
            os.path.join(os.path.dirname(__file__), 'pycacert.pem'))
        # server certificates are NOT verified (test client)
        self.configuration.verify_mode = ssl.CERT_NONE
        self.load_session_ticket()

        self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        self.socket.settimeout(self.socket_timeout)
        self._quic = None  # QuicConnection, created by connect()
        self._http = None  # H3Connection, created by connect()
        self._connected = False
        self._closed = False
        self._receiving = False  # True while recv() is draining the socket
        self._history = []
        self._allow_push = enable_push
        self._response = None  # response of the request in flight, if any

    def connect(self):
        """Create the QUIC/H3 connection, send the client Initial, pump replies."""
        host = self.addr[0]
        try:
            ipaddress.ip_address(host)
        except ValueError:
            # a host name (not an IP literal) is used for SNI
            self.configuration.server_name = host

        self._quic = QuicConnection(
            configuration=self.configuration,
            session_ticket_handler=self.save_session_ticket)
        self._http = H3Connection(self._quic)
        self._quic.connect(self.addr, now=time.monotonic())
        self.transmit()

    def close(self):
        """Close the QUIC connection and drain remaining datagrams.

        Safe to call repeatedly, before connect(), and from event handling.
        """
        if self._closed:
            return
        self._closed = True
        if self._quic is None:
            return  # never connected
        self._quic.close()
        self._flush()
        if not self._receiving:
            self.recv()

    def _flush(self):
        """Send every datagram queued by aioquic; return how many were sent."""
        sent = 0
        for data, _addr in self._quic.datagrams_to_send(now=time.monotonic()):
            self.socket.sendto(data, self.addr)
            sent += 1
        return sent

    def transmit(self):
        """Send pending datagrams and, if anything was sent, pump replies.

        When called while recv() is already draining the socket, only sends;
        the running recv() loop picks up the replies. This avoids nesting
        recv() calls, each of which would add its own quiet-period wait.
        """
        if self._flush() and not self._receiving:
            self.recv()

    def recv(self):
        """Read datagrams until the socket is quiet for ``socket_timeout``."""
        was_receiving = self._receiving
        self._receiving = True
        try:
            while True:
                try:
                    data, addr = self.socket.recvfrom(4096)
                except socket.timeout:
                    break
                self._quic.receive_datagram(data, addr, now=time.monotonic())
                self._process_events()
                self._flush()  # ACKs / responses, without recursing into recv
        finally:
            self._receiving = was_receiving

    def save_session_ticket(self, ticket):
        with open(self.session_ticket, "wb") as fp:
            pickle.dump(ticket, fp)

    def load_session_ticket(self):
        """Load a saved session ticket, if a usable one exists."""
        try:
            with open(self.session_ticket, "rb") as fp:
                self.configuration.session_ticket = pickle.load(fp)
        except (FileNotFoundError, EOFError, pickle.UnpicklingError):
            # no ticket yet, or a truncated/corrupt one: full handshake
            pass

    def _process_events(self):
        """Drain QUIC events, tracking connection state and forwarding them."""
        event = self._quic.next_event()
        while event is not None:
            if isinstance(event, events.ConnectionTerminated):
                self._connected = False
                self.close()
            elif isinstance(event, events.HandshakeCompleted):
                self._connected = True
            self.quic_event_received(event)
            event = self._quic.next_event()

    def http_event_received(self, http_event):
        """Record an HTTP/3 event on the response in flight."""
        if isinstance(http_event, HeadersReceived):
            # Only the client-initiated bidirectional request stream
            # (id % 4 == 0) carries the main response; pushed responses
            # arrive on server push streams and must not overwrite it.
            if http_event.stream_id % 4 == 0:
                self._response.headers = dict(http_event.headers)
        elif isinstance(http_event, DataReceived):
            if http_event.stream_id % 4 == 0:
                self._response.data += http_event.data
        elif isinstance(http_event, PushPromiseReceived):
            push_headers = {k.decode(): v.decode() for k, v in http_event.headers}
            self._response.promises.append(push_headers)

    def quic_event_received(self, event):
        """Log a QUIC event and pass it to the HTTP/3 layer."""
        response = self._response
        if response is not None and not isinstance(event, events.StreamDataReceived):
            response.events.append(event)
        # always feed the H3 layer so its state stays consistent
        for http_event in self._http.handle_event(event):
            if response is None:
                # no request in flight (handshake inside connect(), or late
                # packets after handle_request returned): nothing to record
                continue
            if not isinstance(http_event, (HeadersReceived, PushPromiseReceived, DataReceived)):
                # logging only control frames
                response.events.append(http_event)
            self.http_event_received(http_event)

    def handle_request(self, request):
        """Send *request* on a new stream and wait until replies go quiet.

        :raises ConnectionClosed: if close() was already called.
        :returns: ``request.response``, filled in by the event handlers.
        """
        if self._closed:
            raise ConnectionClosed
        if not self._connected:
            self.connect()

        self._response = request.response
        try:
            stream_id = self._quic.get_next_available_stream_id()
            self._response.stream_id = stream_id
            self._http.send_headers(
                stream_id=stream_id,
                headers=[
                    (b":method", request.method.encode("utf8")),
                    (b":scheme", request.url.scheme.encode("utf8")),
                    (b":authority", request.url.authority.encode("utf8")),
                    (b":path", request.url.full_path.encode("utf8")),
                    (b"user-agent", b"aioquic"),
                ] + [
                    (k.encode("utf8"), v.encode("utf8"))
                    for (k, v) in request.headers.items()
                ],
            )
            self._http.send_data(stream_id=stream_id, data=request.content, end_stream=True)
            self.transmit()
        finally:
            # never leave a finished (or failed) response attached
            self._response = None
        return request.response

    def get(self, url, headers=None):
        return self.request('GET', url, b'', headers)

    def post(self, url, data, headers=None):
        return self.request('POST', url, data, headers)

    def request(self, method, url, data=b'', headers=None):
        """Build an HttpRequest for ``https://<netloc><url>`` and send it.

        A fresh headers dict is used per call (no shared mutable default).
        """
        req = HttpRequest(
            method, 'https://{}{}'.format(self.netloc, url), data,
            {} if headers is None else headers)
        return self.handle_request(req)