async def _handle_events(
    self, connection: QuicConnection, client: Optional[Tuple[str, int]] = None
) -> None:
    """Drain the event queue of ``connection`` and react to each event.

    Events are pulled with ``connection.next_event()`` until it returns
    ``None`` and are dispatched by type:

    * ``ProtocolNegotiated`` builds an ``H3Protocol`` for the connection and
      stores it in ``self.http_connections``.
    * ``ConnectionIdIssued`` / ``ConnectionIdRetired`` add / delete the
      connection id in ``self.connections``.
    * ``ConnectionTerminated`` triggers no bookkeeping in this method.

    After the type-specific step, every event is forwarded to the HTTP
    protocol registered for the connection, if there is one. When the queue
    is empty the pending datagrams are flushed and, if the connection reports
    a timer, ``_handle_timer`` is scheduled through ``self.call_at``.

    Args:
        connection: The QUIC connection whose events are processed.
        client: Handed to ``H3Protocol`` when the protocol is negotiated
            (presumably the peer address -- confirm against callers).
    """
    while True:
        quic_event = connection.next_event()
        if quic_event is None:
            break

        if isinstance(quic_event, ConnectionTerminated):
            # Nothing to clean up here for a terminated connection.
            pass
        elif isinstance(quic_event, ProtocolNegotiated):
            flush = partial(self.send_all, connection)
            self.http_connections[connection] = H3Protocol(
                self.config,
                client,
                self.server,
                self.spawn_app,
                connection,
                flush,
            )
        elif isinstance(quic_event, ConnectionIdIssued):
            self.connections[quic_event.connection_id] = connection
        elif isinstance(quic_event, ConnectionIdRetired):
            del self.connections[quic_event.connection_id]

        # The HTTP layer sees every event, including the ones handled above.
        if connection in self.http_connections:
            http_protocol = self.http_connections[connection]
            await http_protocol.handle(quic_event)

    await self.send_all(connection)

    deadline = connection.get_timer()
    if deadline is not None:
        self.call_at(deadline, partial(self._handle_timer, connection))
async def handle(self, event: Event) -> None:
    """Process one transport-level event.

    Only ``RawData`` events lead to any work:

    1. The QUIC header is parsed; datagrams whose header cannot be parsed
       are dropped.
    2. If the header carries a version that is not in
       ``self.quic_config.supported_versions``, a version negotiation packet
       is sent back and processing stops.
    3. The connection is looked up by destination connection id. When none
       exists, a new ``QuicConnection`` is created only for an INITIAL packet
       of at least 1200 bytes, and registered under both the destination
       connection id and the connection's own ``host_cid``.
    4. If a connection is available, the datagram is fed to it and its
       events are processed.

    ``Closed`` events are recognised but need no action.
    """
    if not isinstance(event, RawData):
        if isinstance(event, Closed):
            # Recognised, but there is nothing to tear down here.
            pass
        return

    payload = event.data
    peer = event.address

    try:
        header = pull_quic_header(Buffer(data=payload), host_cid_length=8)
    except ValueError:
        return

    supported = self.quic_config.supported_versions
    if header.version is not None and header.version not in supported:
        negotiation = encode_quic_version_negotiation(
            source_cid=header.destination_cid,
            destination_cid=header.source_cid,
            supported_versions=self.quic_config.supported_versions,
        )
        await self.send(RawData(data=negotiation, address=peer))
        return

    connection = self.connections.get(header.destination_cid)
    if connection is None:
        may_open = (
            len(payload) >= 1200 and header.packet_type == PACKET_TYPE_INITIAL
        )
        if may_open:
            connection = QuicConnection(
                configuration=self.quic_config, original_connection_id=None
            )
            self.connections[header.destination_cid] = connection
            self.connections[connection.host_cid] = connection

    if connection is None:
        return

    connection.receive_datagram(payload, peer, now=self.now())
    await self._handle_events(connection, peer)
async def _handle_timer(self, timer: float, connection: QuicConnection) -> None:
    """Sleep until ``timer`` and then let ``connection`` handle its timer.

    The delay is ``timer - self.context.time()``, clamped at zero. After the
    sleep, nothing happens when ``connection._close_at`` is ``None``;
    otherwise the connection's timer handler runs and the resulting events
    are processed with no client address.

    Args:
        timer: Target time, on the same clock as ``self.context.time()``.
        connection: The QUIC connection the timer belongs to.
    """
    delay = timer - self.context.time()
    await self.context.sleep(max(0, delay))

    if connection._close_at is None:
        return

    connection.handle_timer(now=self.context.time())
    await self._handle_events(connection, None)
def _make_connection(self, channel, data):
    """Attach this handler to a QUIC connection for an incoming datagram.

    On return ``self.quic`` / ``self.conn`` are set when a connection is
    available, and left untouched when the datagram was answered or dropped.

    Steps, in order:

    1. Parse the QUIC header; a datagram whose header cannot be parsed is
       dropped.
    2. Answer with a version negotiation packet when the header carries a
       version that ``ctx.supported_versions`` does not contain.
    3. If ``self.conns`` already maps the destination connection id to an
       HTTP/3 connection, close the channel it was linked to, link it to
       ``channel`` and adopt it.
    4. Otherwise only an INITIAL packet of at least 1200 bytes may open a new
       connection. With stateless retry enabled, a packet without a token is
       answered with a retry packet and a packet with an invalid token is
       dropped.

    Args:
        channel: Channel the datagram arrived on; replies are pushed to it.
        data: Raw datagram bytes.
    """
    ctx = channel.server.ctx

    try:
        header = pull_quic_header(
            Buffer(data=data), host_cid_length=ctx.connection_id_length
        )
    except ValueError:
        # Malformed header: drop the datagram instead of letting the parse
        # error propagate out of the handler.
        return

    # version negotiation
    if header.version is not None and header.version not in ctx.supported_versions:
        channel.push(
            encode_quic_version_negotiation(
                source_cid=header.destination_cid,
                destination_cid=header.source_cid,
                supported_versions=ctx.supported_versions,
            )
        )
        return

    conn = self.conns.get(header.destination_cid)
    if conn:
        # Known connection arriving on a new channel: re-link it.
        conn._linked_channel.close()
        conn._linked_channel = channel
        self.quic = conn._quic
        self.conn = conn
        return

    if header.packet_type != PACKET_TYPE_INITIAL or len(data) < 1200:
        return

    original_connection_id = None
    if self._retry is not None:
        if not header.token:
            # create a retry token
            channel.push(
                encode_quic_retry(
                    version=header.version,
                    source_cid=os.urandom(8),
                    destination_cid=header.source_cid,
                    original_destination_cid=header.destination_cid,
                    retry_token=self._retry.create_token(
                        channel.addr, header.destination_cid
                    ),
                )
            )
            return
        try:
            original_connection_id = self._retry.validate_token(
                channel.addr, header.token
            )
        except ValueError:
            return

    self.quic = QuicConnection(
        configuration=ctx,
        logger_connection_id=original_connection_id or header.destination_cid,
        original_connection_id=original_connection_id,
        session_ticket_fetcher=channel.server.ticket_store.pop,
        session_ticket_handler=channel.server.ticket_store.add,
    )
    self.conn = H3Connection(self.quic)
    self.conn._linked_channel = channel
def connect(self):
    """Create the QUIC and HTTP/3 connections and start the handshake.

    When the host part of ``self.addr`` is not an IP address literal it is
    stored as ``server_name`` on ``self.configuration``. A new
    ``QuicConnection`` (saving session tickets via
    ``self.save_session_ticket``) is wrapped in an ``H3Connection``, the QUIC
    connection is asked to connect to ``self.addr`` and the resulting
    datagrams are transmitted.
    """
    host, _port = self.addr

    try:
        ipaddress.ip_address(host)
    except ValueError:
        # Not an IP literal, so it can serve as the server name.
        sni = host
    else:
        sni = None
    if sni is not None:
        self.configuration.server_name = sni

    quic = QuicConnection(
        configuration=self.configuration,
        session_ticket_handler=self.save_session_ticket,
    )
    self._quic = quic
    self._http = H3Connection(quic)
    quic.connect(self.addr, now=time.monotonic())
    self.transmit()
async def connect(
    host: str,
    port: int,
    *,
    configuration: Optional[QuicConfiguration] = None,
    create_protocol: Optional[Callable] = QuicFactorySocket,
    session_ticket_handler: Optional[SessionTicketHandler] = None,
    stream_handler: Optional[QuicStreamHandler] = None,
    wait_connected: bool = True,
    local_port: int = 0,
) -> AsyncGenerator[QuicFactorySocket, None]:
    """Open a QUIC client connection to ``host:port`` and yield its protocol.

    This is an async generator that yields exactly once; it is presumably
    meant to be wrapped as an async context manager -- confirm at the
    definition/call site. When the generator is finalised, the protocol is
    closed, its closure is awaited and the datagram transport is closed.
    That cleanup also runs when the handshake itself fails.

    Args:
        host: Host name or IP address literal of the server. A host name is
            also stored as ``server_name`` on the configuration.
        port: Server port.
        configuration: QUIC configuration; a client configuration is created
            when omitted. Note that ``server_name`` is written onto the
            object that was passed in.
        create_protocol: Factory called as
            ``create_protocol(connection, stream_handler=...)``.
        session_ticket_handler: Passed to ``QuicConnection``.
        stream_handler: Passed to ``create_protocol``.
        wait_connected: When true, wait for ``protocol.wait_connected()``
            before yielding.
        local_port: Local UDP port to bind (``0`` by default).

    Yields:
        The protocol instance returned by ``create_datagram_endpoint``.
    """
    loop = asyncio.get_event_loop()
    local_host = "::"

    # An IP literal cannot be used as a server name.
    try:
        ipaddress.ip_address(host)
        server_name = None
    except ValueError:
        server_name = host

    infos = await loop.getaddrinfo(host, port, type=socket.SOCK_DGRAM)
    addr = infos[0][4]
    if len(addr) == 2:
        # Two-element (host, port) address: rewrite it as an IPv4-mapped
        # four-element address, matching the "::" local bind below.
        addr = ("::ffff:" + addr[0], addr[1], 0, 0)

    # prepare QUIC connection
    if configuration is None:
        configuration = QuicConfiguration(is_client=True)
    if server_name is not None:
        configuration.server_name = server_name
    connection = QuicConnection(
        configuration=configuration,
        session_ticket_handler=session_ticket_handler,
    )

    transport, protocol = await loop.create_datagram_endpoint(
        lambda: create_protocol(connection, stream_handler=stream_handler),
        local_addr=(local_host, local_port),
    )
    protocol = cast(QuicFactorySocket, protocol)
    try:
        # Connecting happens inside the try block so that a failed handshake
        # still releases the endpoint.
        protocol.connect(addr)
        if wait_connected:
            await protocol.wait_connected()
        yield protocol
    finally:
        protocol.close()
        await protocol.wait_closed()
        transport.close()
async def connection(self, message):
    """Open a QUIC client connection for ``message`` and return its protocol.

    The target comes from ``message.unresolved_remote`` when it is set,
    otherwise from the ``uri_host`` / ``uri_port`` options; a missing port
    falls back to ``self.default_port``. The resolved address is stored in
    ``self.addr``, the protocol object in ``self.quic`` and ``self.con`` is
    set to ``True`` once the connection is established.

    Raises:
        ValueError: If neither source provides a host.
        error.NetworkError: If an ``OSError`` occurs while creating the
            endpoint or connecting.
    """
    remote = message.unresolved_remote
    if remote is not None:
        host, port = util.hostportsplit(remote)
        port = port or self.default_port
    else:
        host = message.opt.uri_host
        port = message.opt.uri_port or self.default_port
        if host is None:
            raise ValueError(
                "No location found to send message to (neither in .opt.uri_host nor in .remote)"
            )

    # Only a non-IP host can serve as the server name.
    try:
        ipaddress.ip_address(host)
    except ValueError:
        server_name = host
    else:
        server_name = None

    resolved = await self.loop.getaddrinfo(host, port, type=socket.SOCK_DGRAM)
    self.addr = resolved[0][4]

    # NOTE(review): alpn_protocols is given a bare string rather than a list
    # of protocol names -- confirm this matches what the peer expects.
    config = QuicConfiguration(
        is_client=True,
        alpn_protocols='coap',
        idle_timeout=864000,
        server_name=server_name,
    )
    # NOTE(review): certificate verification is switched off here.
    config.verify_mode = ssl.CERT_NONE
    if config.server_name is None:
        config.server_name = server_name

    self.quic = Quic(QuicConnection(configuration=config))
    self.quic.ctx = self

    try:
        _transport, protocol = await self.loop.create_datagram_endpoint(
            lambda: self.quic, remote_addr=(host, port)
        )
        protocol.connect(self.addr)
        await protocol.wait_connected()
        self.con = True
    except OSError:
        raise error.NetworkError("Connection failed to %r" % host)
    return protocol
async def send_all(self, connection: QuicConnection) -> None:
    """Send every datagram that ``connection`` currently has queued.

    Each ``(data, address)`` pair returned by
    ``connection.datagrams_to_send`` is wrapped in a ``RawData`` event and
    passed to ``self.send``.
    """
    pending = connection.datagrams_to_send(now=self.now())
    for payload, destination in pending:
        await self.send(RawData(data=payload, address=destination))
async def _handle_timer(self, connection: QuicConnection) -> None:
    """Run the timer handler of ``connection`` and process what it produced.

    Nothing is done when ``connection._close_at`` is ``None``. Otherwise the
    connection handles its timer at ``self.now()`` and the resulting events
    are processed with no client address.
    """
    if connection._close_at is None:
        return

    connection.handle_timer(now=self.now())
    await self._handle_events(connection, None)
def make_client():
    """Return a client ``QuicConnection`` that trusts ``SERVER_CACERTFILE``.

    The connection's private ``_ack_delay`` attribute is set to ``0``.
    """
    config = QuicConfiguration(is_client=True)
    config.load_verify_locations(cafile=SERVER_CACERTFILE)

    quic_client = QuicConnection(configuration=config)
    quic_client._ack_delay = 0
    return quic_client
def make_connection():
    """Return an ``H3Connection`` wrapping a server-side ``QuicConnection``.

    The server configuration loads its certificate chain from
    ``SERVER_CERTFILE`` / ``SERVER_KEYFILE`` and the QUIC connection's
    private ``_ack_delay`` attribute is set to ``0``.
    """
    config = QuicConfiguration(is_client=False)
    config.load_cert_chain(SERVER_CERTFILE, SERVER_KEYFILE)

    quic_server = QuicConnection(configuration=config)
    quic_server._ack_delay = 0
    return h3.H3Connection(quic_server)
def datagram_received(self, data: Union[bytes, Text], addr: NetworkAddress) -> None:
    """Route an incoming datagram to the protocol that owns its connection.

    Every call increments the module-level ``totalDatagrams`` counter.

    Processing steps:

    1. Parse the QUIC header; a datagram whose header cannot be parsed is
       dropped.
    2. Reply with a version negotiation packet when the header carries a
       version missing from the configuration's supported versions.
    3. Look up the protocol by destination connection id. If none exists and
       the datagram is an INITIAL packet of at least 1200 bytes, create a new
       connection and protocol:

       * With retry enabled, a packet without a token is answered with a
         retry packet and a packet whose token fails validation is dropped.
         A valid token supplies both the original destination connection id
         and the retry source connection id.
       * Without retry, the header's destination connection id is used as
         the original destination connection id.

       The new protocol is registered under both the header's destination
       connection id and the connection's ``host_cid``.
    4. Hand the datagram to the protocol, if there is one.

    Args:
        data: The received datagram (treated as ``bytes``).
        addr: Address the datagram came from.
    """
    global totalDatagrams

    data = cast(bytes, data)
    buf = Buffer(data=data)
    totalDatagrams += 1

    try:
        header = pull_quic_header(
            buf, host_cid_length=self._configuration.connection_id_length
        )
    except ValueError:
        return

    # version negotiation
    if (
        header.version is not None
        and header.version not in self._configuration.supported_versions
    ):
        self._transport.sendto(
            encode_quic_version_negotiation(
                source_cid=header.destination_cid,
                destination_cid=header.source_cid,
                supported_versions=self._configuration.supported_versions,
            ),
            addr,
        )
        return

    protocol = self._protocols.get(header.destination_cid, None)
    original_destination_connection_id: Optional[bytes] = None
    retry_source_connection_id: Optional[bytes] = None
    if (
        protocol is None
        and len(data) >= 1200
        and header.packet_type == PACKET_TYPE_INITIAL
    ):
        # retry
        if self._retry is not None:
            if not header.token:
                # create a retry token
                source_cid = os.urandom(8)
                self._transport.sendto(
                    encode_quic_retry(
                        version=header.version,
                        source_cid=source_cid,
                        destination_cid=header.source_cid,
                        original_destination_cid=header.destination_cid,
                        retry_token=self._retry.create_token(
                            addr, header.destination_cid, source_cid
                        ),
                    ),
                    addr,
                )
                return
            # validate retry token; the ids it carries must reach the new
            # connection, so they are bound to the names used below.
            try:
                (
                    original_destination_connection_id,
                    retry_source_connection_id,
                ) = self._retry.validate_token(addr, header.token)
            except ValueError:
                return
        else:
            original_destination_connection_id = header.destination_cid

        # create new connection
        connection = QuicConnection(
            configuration=self._configuration,
            original_destination_connection_id=original_destination_connection_id,
            retry_source_connection_id=retry_source_connection_id,
            session_ticket_handler=self._session_ticket_handler,
            session_ticket_fetcher=self._session_ticket_fetcher,
        )
        # The protocol factory builds the per-connection protocol object.
        protocol = self._create_protocol(
            connection, stream_handler=self._stream_handler
        )
        protocol.connection_made(self._transport)

        # register callbacks
        protocol._connection_id_issued_handler = partial(
            self._connection_id_issued, protocol=protocol
        )
        protocol._connection_id_retired_handler = partial(
            self._connection_id_retired, protocol=protocol
        )
        protocol._connection_terminated_handler = partial(
            self._connection_terminated, protocol=protocol
        )

        self._protocols[header.destination_cid] = protocol
        self._protocols[connection.host_cid] = protocol

    if protocol is not None:
        protocol.datagram_received(data, addr)
class http3_request_handler(http2_handler.http2_request_handler):
    """HTTP/3 request handler running on top of a QUIC connection.

    ``conns`` is a class-level dict shared by all handler instances; it maps
    QUIC connection ids to HTTP/3 connection objects so that a datagram
    arriving on a new channel can be re-attached to its existing connection.
    """

    producer_class = http3_producer
    stateless_retry = True
    conns = {}  # connection id -> H3Connection, shared across handlers
    errno = ErrorCode

    def __init__(self, handler, request):
        self._default_varialbes(handler, request)
        self.quic = None  # QUIC protocol
        self.conn = None  # HTTP3 Protocol
        self._push_map = {}  # push id -> promised stream id
        self._retry = QuicRetryTokenHandler() if self.stateless_retry else None

    def _initiate_shutdown(self):
        """Shutdown phase I: ask the HTTP/3 connection to shut down (goaway)."""
        with self._plock:
            self.conn.shutdown(self._shutdown_reason[-1])

    def _proceed_shutdown(self):
        """Shutdown phase II: close the QUIC connection and flush output."""
        if self.quic:
            errcode, msg, _ = self._shutdown_reason
            with self._plock:
                self.quic.close(errcode, reason_phrase=msg or '')
            self.send_data()

    def _make_connection(self, channel, data):
        """Attach this handler to a QUIC connection for an incoming datagram.

        On return ``self.quic`` / ``self.conn`` are set when a connection is
        available, and left untouched when the datagram was answered or
        dropped. Malformed headers are dropped; unsupported versions get a
        version negotiation reply; a connection id already present in
        ``conns`` is re-linked to ``channel``; otherwise only an INITIAL
        packet of at least 1200 bytes may open a new connection, subject to
        the stateless retry token exchange when retry is enabled.
        """
        ctx = channel.server.ctx

        try:
            header = pull_quic_header(
                Buffer(data=data), host_cid_length=ctx.connection_id_length
            )
        except ValueError:
            # Malformed header: drop the datagram instead of letting the
            # parse error propagate out of the handler.
            return

        # version negotiation
        if header.version is not None and header.version not in ctx.supported_versions:
            channel.push(
                encode_quic_version_negotiation(
                    source_cid=header.destination_cid,
                    destination_cid=header.source_cid,
                    supported_versions=ctx.supported_versions,
                )
            )
            return

        conn = self.conns.get(header.destination_cid)
        if conn:
            # Known connection arriving on a new channel: re-link it.
            conn._linked_channel.close()
            conn._linked_channel = channel
            self.quic = conn._quic
            self.conn = conn
            return

        if header.packet_type != PACKET_TYPE_INITIAL or len(data) < 1200:
            return

        original_connection_id = None
        if self._retry is not None:
            if not header.token:
                # create a retry token
                channel.push(
                    encode_quic_retry(
                        version=header.version,
                        source_cid=os.urandom(8),
                        destination_cid=header.source_cid,
                        original_destination_cid=header.destination_cid,
                        retry_token=self._retry.create_token(
                            channel.addr, header.destination_cid
                        ),
                    )
                )
                return
            try:
                original_connection_id = self._retry.validate_token(
                    channel.addr, header.token
                )
            except ValueError:
                return

        self.quic = QuicConnection(
            configuration=ctx,
            logger_connection_id=original_connection_id or header.destination_cid,
            original_connection_id=original_connection_id,
            session_ticket_fetcher=channel.server.ticket_store.pop,
            session_ticket_handler=channel.server.ticket_store.add,
        )
        self.conn = H3Connection(self.quic)
        self.conn._linked_channel = channel

    def _handle_events(self, events):
        """Apply a batch of HTTP/3 events, then flush pending output."""
        for event in events:
            if isinstance(event, HeadersReceived):
                created = self.handle_request(
                    event.stream_id,
                    event.headers,
                    has_data_frame=not event.stream_ended,
                    push_id=event.push_id,
                )
                if created:
                    r = self.get_request(event.stream_id)
                    if event.stream_ended:
                        r.set_stream_ended()

            elif isinstance(event, DataReceived) and event.data:
                r = self.get_request(event.stream_id)
                if not r:
                    self.close(self.errno.INTERNAL_ERROR)
                else:
                    try:
                        r.channel.set_data(event.data, len(event.data))
                    except ValueError:
                        self.close(self.errno.INTERNAL_ERROR)

                    if event.stream_ended:
                        if r.collector:
                            r.channel.handle_read()
                            r.channel.found_terminator()
                        r.set_stream_ended()
                        if r.response.is_done():
                            self.remove_request(event.stream_id)

            elif isinstance(event, PushCanceled):
                with self._clock:
                    try:
                        del self._push_map[event.push_id]
                    except KeyError:
                        pass
                self.remove_push_stream(event.push_id)

        self.send_data()

    def process_quic_events(self):
        """Drain the QUIC event queue and dispatch each event by type."""
        while True:
            with self._plock:
                event = self.quic.next_event()
            if event is None:
                break

            if isinstance(event, events.StreamDataReceived):
                h3_events = self.conn.handle_event(event)
                if h3_events:
                    self._handle_events(h3_events)

            elif isinstance(event, events.ConnectionIdIssued):
                self.conns[event.connection_id] = self.conn

            elif isinstance(event, events.ConnectionIdRetired):
                assert self.conns[event.connection_id] == self.conn
                conn = self.conns.pop(event.connection_id)
                conn._linked_channel = None

            elif isinstance(event, events.ConnectionTerminated):
                # Forget every connection id that maps to this connection.
                for cid, conn in list(self.conns.items()):
                    if conn == self.conn:
                        conn._linked_channel = None
                        del self.conns[cid]
                self._terminate_connection()

            elif isinstance(event, events.HandshakeCompleted):
                pass

            elif isinstance(event, events.PingAcknowledged):
                # for now nothing to do, channel will be extended by event time
                pass

    def collect_incoming_data(self, data):
        """Feed a received datagram to QUIC, creating the connection if needed."""
        if self.quic is None:
            self._make_connection(self.channel, data)
            if self.quic is None:
                # The datagram was answered or dropped without a connection.
                return

        with self._plock:
            self.quic.receive_datagram(data, self.channel.addr, time.monotonic())
        self.process_quic_events()
        self.send_data()

    def reset_stream(self, stream_id):
        """Not supported for HTTP/3; always raises ``AttributeError``."""
        raise AttributeError(
            'HTTP/3 can cancel for only push stream, use cancel_push (push_id)'
        )

    def remove_stream(self, stream_id):
        """Not supported for HTTP/3; always raises ``AttributeError``."""
        raise AttributeError('use remove_push_stream (push_id)')

    def remove_push_stream(self, push_id):
        """Remove the stream of a push the client cancelled, if still tracked."""
        # received by client and just reset
        with self._clock:
            try:
                stream_id = self._push_map.pop(push_id)
            except KeyError:
                return
        super().remove_stream(stream_id)

    def cancel_push(self, push_id):
        """Cancel a push towards the client and remove its stream."""
        # send to client and reset
        with self._clock:
            try:
                stream_id = self._push_map.pop(push_id)
            except KeyError:
                return  # already done or canceled
        with self._plock:
            self.conn.cancel_push(push_id)
        super().remove_stream(stream_id)
        self.send_data()

    def data_to_send(self):
        """Return the datagram payloads QUIC wants to send right now.

        Returns an empty list when there is no QUIC connection or the handler
        is closed; when QUIC has nothing to send the result of
        ``self._data_exhausted()`` is returned instead.
        """
        if self.quic is None or self._closed:
            return []
        self._data_from_producers()
        with self._plock:
            data_to_send = [
                data
                for data, _addr in self.quic.datagrams_to_send(now=time.monotonic())
            ]
        return data_to_send or self._data_exhausted()

    def pushable(self):
        """Return whether a push may be started (same as ``request_acceptable``)."""
        return self.request_acceptable()

    def _maybe_duplicated(self, stream_id, headers):
        """Send a duplicate-push frame if the request's path is being pushed.

        Only the leading pseudo headers (names starting with ``:``) are
        inspected. Returns ``False`` for non-GET requests and for paths with
        no live push; returns ``True`` after the duplicate push was sent.
        """
        path = None
        for k, v in headers:
            if k[0] != 58:  # 58 == ord(':'); pseudo headers come first
                break
            elif k == b':path':
                path = v.decode()
            elif k == b':method' and v != b"GET":
                return False
        assert path, ':path is missing'

        with self._clock:
            push_id = self._pushed_pathes.get(path)
        if push_id is None or push_id not in self._push_map:
            return False

        with self._plock:
            self.conn.send_duplicate_push(stream_id, push_id)
        return True

    def handle_request(self, stream_id, headers, has_data_frame=False, push_id=None):
        """Handle a request unless it duplicates a push already in progress.

        Returns ``False`` when a duplicate push was sent instead; otherwise
        returns whatever the parent class's ``handle_request`` returns.
        """
        if push_id is None:
            with self._clock:
                pushing = len(self._pushed_pathes)
            if pushing and self._maybe_duplicated(stream_id, headers):
                return False
        return super().handle_request(stream_id, headers, has_data_frame)

    def push_promise(self, stream_id, request_headers, addtional_request_headers):
        """Send a push promise and feed the promised request to the handler.

        ``request_headers[0]`` must be the ``:path`` header. Nothing happens
        when no push id is available.
        """
        name_, path = request_headers[0]
        assert name_ == ':path', ':path header missing'

        headers = [
            (k.encode(), v.encode())
            for k, v in request_headers + addtional_request_headers
        ]
        try:
            promise_stream_id = self.conn.send_push_promise(
                stream_id=stream_id, headers=headers
            )
            try:
                push_id = self.conn.get_push_id(promise_stream_id)
            except AttributeError:
                # No get_push_id on this connection object: derive the id
                # from the connection's private push id counter.
                push_id = self.conn._next_push_id - 1
        except NoAvailablePushIDError:
            return

        with self._clock:
            self._push_map[push_id] = promise_stream_id
            self._pushed_pathes[path] = push_id

        self._handle_events([
            HeadersReceived(
                headers=headers,
                stream_ended=True,
                stream_id=promise_stream_id,
                push_id=push_id,
            )
        ])
class Connection:
    """Blocking HTTP/3 client connection over a plain UDP socket.

    Requests are issued with ``get`` / ``post`` / ``request``; each call
    sends the request, pumps the socket until a receive times out and
    returns the request's response object.
    """

    # NOTE(review): the ticket is pickled to a fixed path under /tmp and
    # unpickled on start-up; unpickling a file other users may be able to
    # write is unsafe -- consider a private location.
    session_ticket = '/tmp/http3-session-ticket.pik'
    socket_timeout = 1

    def __init__(self, addr, enable_push=True):
        """Prepare configuration and socket for ``addr`` (``host[:port]``).

        The port defaults to 443 when ``addr`` has no parsable ``:port``
        suffix.
        """
        # prepare configuration
        self.netloc = addr
        try:
            host, port = addr.split(":", 1)
            port = int(port)
        except ValueError:
            host, port = addr, 443
        self.addr = (host, port)

        self.configuration = QuicConfiguration(is_client=True, alpn_protocols=H3_ALPN)
        self.configuration.load_verify_locations(
            os.path.join(os.path.dirname(__file__), 'pycacert.pem')
        )
        # NOTE(review): certificate verification is switched off here even
        # though a CA file was just loaded.
        self.configuration.verify_mode = ssl.CERT_NONE
        self.load_session_ticket()

        self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        self.socket.settimeout(self.socket_timeout)

        self._connected = False
        self._closed = False
        self._history = []
        self._allow_push = enable_push
        self._response = None

    def connect(self):
        """Create the QUIC and HTTP/3 connections and start the handshake."""
        host = self.addr[0]
        try:
            ipaddress.ip_address(host)
            server_name = None
        except ValueError:
            # Not an IP literal, so it can serve as the server name.
            server_name = host
        if server_name is not None:
            self.configuration.server_name = server_name

        self._quic = QuicConnection(
            configuration=self.configuration,
            session_ticket_handler=self.save_session_ticket,
        )
        self._http = H3Connection(self._quic)
        self._quic.connect(self.addr, now=time.monotonic())
        self.transmit()

    def close(self):
        """Flush pending traffic and mark the connection closed."""
        self.transmit()
        self.recv()
        self._closed = True

    def transmit(self):
        """Send all queued datagrams, then read whatever comes back."""
        dts = self._quic.datagrams_to_send(now=time.monotonic())
        if not dts:
            return
        # Datagrams always go to the configured server address.
        for data, _addr in dts:
            self.socket.sendto(data, self.addr)
        self.recv()

    def recv(self):
        """Receive datagrams until the socket times out, processing each."""
        while True:
            try:
                data, addr = self.socket.recvfrom(4096)
            except socket.timeout:
                break
            self._quic.receive_datagram(data, addr, now=time.monotonic())
            self._process_events()
            self.transmit()

    def save_session_ticket(self, ticket):
        """Pickle ``ticket`` to the ``session_ticket`` path."""
        with open(self.session_ticket, "wb") as fp:
            pickle.dump(ticket, fp)

    def load_session_ticket(self):
        """Load a previously saved session ticket, if the file exists."""
        try:
            with open(self.session_ticket, "rb") as fp:
                # NOTE(review): pickle.load runs on a file under /tmp.
                self.configuration.session_ticket = pickle.load(fp)
        except FileNotFoundError:
            pass

    def _process_events(self):
        """Drain QUIC events, track connection state and forward each event."""
        event = self._quic.next_event()
        while event is not None:
            if isinstance(event, events.ConnectionTerminated):
                self._connected = False
                self.close()
            elif isinstance(event, events.HandshakeCompleted):
                self._connected = True
            elif isinstance(event, events.PingAcknowledged):
                pass
            self.quic_event_received(event)
            event = self._quic.next_event()

    def http_event_received(self, http_event):
        """Record an HTTP/3 event on the current response object."""
        if isinstance(http_event, HeadersReceived):
            self._response.headers = {k: v for k, v in http_event.headers}
        elif isinstance(http_event, DataReceived):
            # Only data on streams whose id is a multiple of 4 is appended.
            if http_event.stream_id % 4 == 0:
                self._response.data += http_event.data
        elif isinstance(http_event, PushPromiseReceived):
            push_headers = {k.decode(): v.decode() for k, v in http_event.headers}
            self._response.promises.append(push_headers)

    def quic_event_received(self, event):
        """Log ``event`` on the response and pass it to the HTTP layer."""
        if self._response and not isinstance(event, events.StreamDataReceived):
            self._response.events.append(event)

        # pass event to the HTTP layer
        for http_event in self._http.handle_event(event):
            if not isinstance(
                http_event, (HeadersReceived, PushPromiseReceived, DataReceived)
            ):
                # logging only control frames
                self._response.events.append(http_event)
            self.http_event_received(http_event)

    def handle_request(self, request):
        """Send ``request`` on a new stream and return its response object.

        Connects first when no handshake has completed yet.

        Raises:
            ConnectionClosed: If the connection has already been closed.
        """
        if self._closed:
            raise ConnectionClosed
        if not self._connected:
            self.connect()

        self._response = request.response
        stream_id = self._quic.get_next_available_stream_id()
        self._response.stream_id = stream_id

        self._http.send_headers(
            stream_id=stream_id,
            headers=[
                (b":method", request.method.encode("utf8")),
                (b":scheme", request.url.scheme.encode("utf8")),
                (b":authority", request.url.authority.encode("utf8")),
                (b":path", request.url.full_path.encode("utf8")),
                (b"user-agent", b"aioquic"),
            ] + [
                (k.encode("utf8"), v.encode("utf8"))
                for (k, v) in request.headers.items()
            ],
        )
        self._http.send_data(stream_id=stream_id, data=request.content, end_stream=True)
        self.transmit()

        self._response = None
        return request.response

    def get(self, url, headers=None):
        """Issue a GET for ``url`` (a path relative to the connection's netloc)."""
        req = HttpRequest(
            'GET',
            'https://{}{}'.format(self.netloc, url),
            b'',
            {} if headers is None else headers,
        )
        return self.handle_request(req)

    def post(self, url, data, headers=None):
        """Issue a POST of ``data`` to ``url``."""
        req = HttpRequest(
            'POST',
            'https://{}{}'.format(self.netloc, url),
            data,
            {} if headers is None else headers,
        )
        return self.handle_request(req)

    def request(self, method, url, data=b'', headers=None):
        """Issue a request with an arbitrary ``method`` to ``url``."""
        req = HttpRequest(
            method,
            'https://{}{}'.format(self.netloc, url),
            data,
            {} if headers is None else headers,
        )
        return self.handle_request(req)