async def _connect(self, urlparsed: ParseResult, verify: bool,
                   ssl_context: SSLContext, http2: bool):
    """Get reader and writer.

    The current connection is kept when ``self.key`` equals the
    ``hostname-port`` key derived from ``urlparsed`` and the writer is not
    closing. Otherwise the current writer (if any) is closed and a new
    connection is opened with ``asyncio.open_connection``.

    In both cases ``self.temp_key`` is set to the computed key and
    ``self._connection_made()`` is awaited.

    Args:
        urlparsed: Parsed target URL; scheme, hostname and port are read.
        verify: When false, hostname checking and certificate verification
            are disabled on the TLS context (https only).
        ssl_context: TLS context used for https; a default client context
            is created when it is falsy. Ignored for any other scheme.
        http2: When true, ``h2`` and ``http/1.1`` are offered through ALPN
            (https only).
    """
    key = '%s-%s' % (urlparsed.hostname, urlparsed.port)

    if self.writer:
        # python 3.6 doesn't have writer.is_closing
        is_closing = getattr(
            self.writer, 'is_closing',
            self.writer._transport.is_closing)  # type: ignore
    else:
        def is_closing():
            return True  # noqa

    if not (self.key and key == self.key and not is_closing()):
        if self.writer:
            self.writer.close()

        is_https = urlparsed.scheme == 'https'
        if is_https:
            ssl_context = ssl_context or ssl.create_default_context(
                ssl.Purpose.SERVER_AUTH,
            )
            # NOTE: the settings below are applied in place, so a context
            # supplied by the caller is modified as well.
            if http2:  # flag will be removed when fully http2 support
                ssl_context.set_alpn_protocols(['h2', 'http/1.1'])
            if not verify:
                ssl_context.check_hostname = False
                ssl_context.verify_mode = ssl.CERT_NONE
        else:
            # Plain-text scheme: a context supplied by the caller must not
            # reach open_connection, otherwise a TLS handshake would be
            # attempted against a non-TLS endpoint.
            ssl_context = None

        port = urlparsed.port or (443 if is_https else 80)
        self.reader, self.writer = await asyncio.open_connection(
            urlparsed.hostname, port, ssl=ssl_context)

    self.temp_key = key
    await self._connection_made()
async def test_alpn_negotiation(self, server_context: ssl.SSLContext,
                                client_context: ssl.SSLContext) -> None:
    """Check that both TLS peers report the same negotiated ALPN protocol.

    The server offers ``dummy1``/``dummy2`` and the client offers
    ``dummy2``/``dummy3``; the test expects ``dummy2`` to be reported by
    the client-side wrapper and to be echoed back by the server thread.
    """
    server_context.set_alpn_protocols(["dummy1", "dummy2"])
    client_context.set_alpn_protocols(["dummy2", "dummy3"])

    listener = server_context.wrap_socket(
        socket.socket(), server_side=True, suppress_ragged_eofs=False)
    listener.settimeout(1)
    listener.bind(("127.0.0.1", 0))
    listener.listen()

    def echo_negotiated_protocol() -> None:
        # Runs in a helper thread: accept one client and send back the
        # protocol name the server side of the handshake selected.
        peer, _ = listener.accept()
        peer.settimeout(1)
        negotiated = peer.selected_alpn_protocol()
        assert negotiated is not None
        peer.send(negotiated.encode())
        peer.close()

    worker = Thread(target=echo_negotiated_protocol)
    worker.start()

    async with await connect_tcp(*listener.getsockname()) as stream:
        tls_stream = await TLSStream.wrap(
            stream, hostname="localhost", ssl_context=client_context)
        assert tls_stream.extra(TLSAttribute.alpn_protocol) == "dummy2"
        echoed = await tls_stream.receive()

    worker.join()
    listener.close()
    assert echoed == b"dummy2"
def _create_ssl_context(self) -> Optional[SSLContext]:
    """Build the TLS context described by ``self.cfg``.

    Returns:
        ``None`` when ``cfg.is_ssl`` is falsy. Otherwise an ``SSLContext``
        created for ``cfg.ssl_version`` with the certificate chain from
        ``cfg.certfile``/``cfg.keyfile`` loaded, optional CA locations and
        cipher list applied, and ``h2``/``http/1.1`` offered through ALPN.
    """
    cfg = self.cfg
    if not cfg.is_ssl:
        return None

    context = SSLContext(cfg.ssl_version)
    context.load_cert_chain(cfg.certfile, cfg.keyfile)

    # CA bundle and cipher list are optional settings.
    if cfg.ca_certs:
        context.load_verify_locations(cfg.ca_certs)
    if cfg.ciphers:
        context.set_ciphers(cfg.ciphers)

    context.set_alpn_protocols(['h2', 'http/1.1'])
    return context
async def test_extra_attributes(self, server_context: ssl.SSLContext,
                                client_context: ssl.SSLContext) -> None:
    """Check the typed extra attributes exposed by a wrapped ``TLSStream``.

    Every public ``SocketAttribute`` must be passed through unchanged from
    the underlying stream, and each ``TLSAttribute`` checked below must
    have the expected value or type.
    """
    for context in (server_context, client_context):
        context.set_alpn_protocols(["h2"])

    listener = server_context.wrap_socket(
        socket.socket(), server_side=True, suppress_ragged_eofs=True)
    listener.settimeout(1)
    listener.bind(("127.0.0.1", 0))
    listener.listen()

    def accept_and_read_one_byte() -> None:
        # Runs in a helper thread: accept one client and read the single
        # byte the client sends at the end of the test.
        peer, _ = listener.accept()
        with peer:
            peer.settimeout(1)
            peer.recv(1)

    worker = Thread(target=accept_and_read_one_byte)
    worker.start()

    async with await connect_tcp(*listener.getsockname()) as stream:
        tls_stream = await TLSStream.wrap(
            stream,
            hostname="localhost",
            ssl_context=client_context,
            standard_compatible=False,
        )
        async with tls_stream:
            # Socket-level attributes are forwarded from the raw stream.
            for name, attribute in SocketAttribute.__dict__.items():
                if name.startswith("_"):
                    continue
                assert tls_stream.extra(attribute) == stream.extra(attribute)

            extra = tls_stream.extra
            assert extra(TLSAttribute.alpn_protocol) == "h2"
            assert isinstance(
                extra(TLSAttribute.channel_binding_tls_unique), bytes)
            assert isinstance(extra(TLSAttribute.cipher), tuple)
            assert isinstance(extra(TLSAttribute.peer_certificate), dict)
            assert isinstance(
                extra(TLSAttribute.peer_certificate_binary), bytes)
            assert extra(TLSAttribute.server_side) is False
            assert isinstance(extra(TLSAttribute.shared_ciphers), list)
            assert isinstance(extra(TLSAttribute.ssl_object), ssl.SSLObject)
            assert extra(TLSAttribute.standard_compatible) is False
            assert extra(TLSAttribute.tls_version).startswith("TLSv")

            await tls_stream.send(b"\x00")

    worker.join()
    listener.close()
async def _connect(self, urlparsed: ParseResult, verify: bool,
                   ssl_context: SSLContext, dns_info,
                   http2: bool) -> None:
    """Get reader and writer.

    The current connection is kept when ``self.key`` equals the
    ``hostname-port`` key derived from ``urlparsed`` and the writer is not
    closing. Otherwise ``self.close()`` is called and a new connection is
    opened through ``open_connection`` using the entries of ``dns_info``.

    In both cases ``self.temp_key`` is set to the computed key and
    ``self._connection_made()`` is awaited.

    Args:
        urlparsed: Parsed target URL; scheme, hostname and port are read.
        verify: When false, hostname checking and certificate verification
            are disabled on the TLS context (https only).
        ssl_context: TLS context used for https; a default client context
            is created when it is falsy. Ignored for any other scheme.
        dns_info: Mapping with a ``hostname`` entry; a copy of it (with
            ``hostname`` renamed to ``server_hostname`` and ``port`` added)
            is expanded into keyword arguments for ``open_connection``.
            The mapping passed in is not modified.
        http2: When true, ``h2`` and ``http/1.1`` are offered through ALPN
            (https only).

    Raises:
        HttpParsingError: If ``urlparsed`` has no hostname.
    """
    if not urlparsed.hostname:
        raise HttpParsingError('missing hostname')

    key = f'{urlparsed.hostname}-{urlparsed.port}'

    if self.writer:
        # python 3.6 doesn't have writer.is_closing
        is_closing = getattr(
            self.writer, 'is_closing',
            self.writer._transport.is_closing)  # type: ignore
    else:
        def is_closing():
            return True  # noqa

    dns_info_copy = dns_info.copy()
    dns_info_copy['server_hostname'] = dns_info_copy.pop('hostname')

    if not (self.key and key == self.key and not is_closing()):
        self.close()

        is_https = urlparsed.scheme == 'https'
        if is_https:
            ssl_context = ssl_context or ssl.create_default_context(
                ssl.Purpose.SERVER_AUTH,
            )
            # NOTE: the settings below are applied in place, so a context
            # supplied by the caller is modified as well.
            # flag will be removed when fully http2 support
            if http2:  # pragma: no cover
                ssl_context.set_alpn_protocols(['h2', 'http/1.1'])
            if not verify:
                ssl_context.check_hostname = False
                ssl_context.verify_mode = ssl.CERT_NONE
        else:
            del dns_info_copy['server_hostname']
            # Plain-text scheme: a context supplied by the caller must not
            # reach open_connection, otherwise a TLS handshake would be
            # attempted against a non-TLS endpoint.
            ssl_context = None

        port = urlparsed.port or (443 if is_https else 80)
        dns_info_copy['port'] = port
        self.reader, self.writer = await open_connection(
            **dns_info_copy, ssl=ssl_context)

    self.temp_key = key
    await self._connection_made()
class SyncHTTPConnection(SyncHTTPTransport):
    """A single synchronous connection to one origin.

    The socket is opened on the first request, and the concrete protocol
    handler (HTTP/1.1 or HTTP/2) is chosen from the HTTP version reported
    by that socket.
    """

    def __init__(
        self,
        origin: Origin,
        http2: bool = False,
        ssl_context: SSLContext = None,
        socket: SyncSocketStream = None,
    ):
        """Store the configuration; no I/O is performed here.

        When ``http2`` is true, ``http/1.1`` and ``h2`` are registered as
        ALPN protocols on the TLS context (a context supplied by the
        caller is modified in place).
        """
        if ssl_context is None:
            ssl_context = SSLContext()

        self.origin = origin
        self.http2 = http2
        self.ssl_context = ssl_context
        self.socket = socket

        if self.http2:
            self.ssl_context.set_alpn_protocols(["http/1.1", "h2"])

        # Protocol handler; stays None until the first request connects.
        self.connection: Union[
            None, SyncHTTP11Connection, SyncHTTP2Connection
        ] = None
        self.is_http11 = False
        self.is_http2 = False
        self.connect_failed = False
        self.expires_at: Optional[float] = None
        self.backend = SyncBackend()

    @property
    def request_lock(self) -> SyncLock:
        """Lock serialising request set-up, created on first access."""
        # We do this lazily, to make sure backend autodetection always
        # runs within an async context.
        try:
            return self._request_lock
        except AttributeError:
            self._request_lock = self.backend.create_lock()
            return self._request_lock

    def request(
        self,
        method: bytes,
        url: URL,
        headers: Headers = None,
        stream: SyncByteStream = None,
        timeout: TimeoutDict = None,
    ) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]], SyncByteStream]:
        """Send a request on this connection, connecting first if needed.

        Raises:
            NewConnectionRequired: If the connection is in a state that
                cannot take another request.
        """
        assert url_to_origin(url) == self.origin

        with self.request_lock:
            if self.state == ConnectionState.PENDING:
                # First use: open the socket unless one was supplied, then
                # pick the protocol handler.
                if not self.socket:
                    logger.trace(
                        "open_socket origin=%r timeout=%r",
                        self.origin, timeout,
                    )
                    self.socket = self._open_socket(timeout)
                self._create_connection(self.socket)
            elif not (
                self.state in (ConnectionState.READY, ConnectionState.IDLE)
                or (self.state == ConnectionState.ACTIVE and self.is_http2)
            ):
                # Only READY/IDLE connections, or ACTIVE HTTP/2 ones, may
                # accept a further request.
                raise NewConnectionRequired()

        assert self.connection is not None
        logger.trace(
            "connection.request method=%r url=%r headers=%r",
            method, url, headers,
        )
        return self.connection.request(method, url, headers, stream, timeout)

    def _open_socket(self, timeout: TimeoutDict = None) -> SyncSocketStream:
        """Open a TCP stream to the origin, using TLS for ``https`` only.

        Any failure marks the connection as failed before re-raising.
        """
        scheme, hostname, port = self.origin
        if timeout is None:
            timeout = {}
        tls_context = self.ssl_context if scheme == b"https" else None

        try:
            return self.backend.open_tcp_stream(
                hostname, port, tls_context, timeout
            )
        except Exception:
            self.connect_failed = True
            raise

    def _create_connection(self, socket: SyncSocketStream) -> None:
        """Instantiate the protocol handler matching the socket's version."""
        http_version = socket.get_http_version()
        logger.trace(
            "create_connection socket=%r http_version=%r",
            socket, http_version,
        )
        if http_version != "HTTP/2":
            self.is_http11 = True
            self.connection = SyncHTTP11Connection(
                socket=socket, ssl_context=self.ssl_context
            )
            return

        self.is_http2 = True
        self.connection = SyncHTTP2Connection(
            socket=socket, backend=self.backend, ssl_context=self.ssl_context
        )

    @property
    def state(self) -> ConnectionState:
        """CLOSED after a failed connect, PENDING before connecting,
        otherwise the state of the underlying protocol handler."""
        if self.connect_failed:
            return ConnectionState.CLOSED
        if self.connection is None:
            return ConnectionState.PENDING
        return self.connection.state

    def is_connection_dropped(self) -> bool:
        """Return False before connecting, else ask the protocol handler."""
        if self.connection is None:
            return False
        return self.connection.is_connection_dropped()

    def mark_as_ready(self) -> None:
        """Forward to the protocol handler, if there is one."""
        if self.connection is None:
            return
        self.connection.mark_as_ready()

    def start_tls(self, hostname: bytes, timeout: TimeoutDict = None) -> None:
        """Upgrade the established connection to TLS; no-op if unconnected."""
        if self.connection is None:
            return
        logger.trace("start_tls hostname=%r timeout=%r", hostname, timeout)
        self.connection.start_tls(hostname, timeout)
        logger.trace(
            "start_tls complete hostname=%r timeout=%r", hostname, timeout
        )
        # Pick up the socket now held by the protocol handler.
        self.socket = self.connection.socket

    def close(self) -> None:
        """Close the protocol handler, if any, under the request lock."""
        with self.request_lock:
            if self.connection is not None:
                self.connection.close()
class AsyncHTTPConnection(AsyncHTTPTransport):
    """A single asynchronous connection to one origin.

    The socket (TCP or Unix domain) is opened on the first request, with
    optional retries, and the protocol handler (HTTP/1.1 or HTTP/2) is
    chosen from the HTTP version reported by that socket.
    """

    def __init__(
        self,
        origin: Origin,
        http2: bool = False,
        uds: str = None,
        ssl_context: SSLContext = None,
        socket: AsyncSocketStream = None,
        local_address: str = None,
        retries: int = 0,
        backend: AsyncBackend = None,
    ):
        """Store the configuration; no I/O is performed here.

        When ``http2`` is true, ``http/1.1`` and ``h2`` are registered as
        ALPN protocols on the TLS context (a context supplied by the
        caller is modified in place).
        """
        if ssl_context is None:
            ssl_context = SSLContext()

        self.origin = origin
        self.http2 = http2
        self.uds = uds
        self.ssl_context = ssl_context
        self.socket = socket
        self.local_address = local_address
        self.retries = retries

        if self.http2:
            self.ssl_context.set_alpn_protocols(["http/1.1", "h2"])

        # Protocol handler; stays None until the first request connects.
        self.connection: Optional[AsyncBaseHTTPConnection] = None
        self.is_http11 = False
        self.is_http2 = False
        self.connect_failed = False
        self.expires_at: Optional[float] = None
        if backend is None:
            backend = AutoBackend()
        self.backend = backend

    def __repr__(self) -> str:
        if self.is_http11:
            http_version = "HTTP/1.1"
        elif self.is_http2:
            http_version = "HTTP/2"
        else:
            http_version = "UNKNOWN"
        return f"<AsyncHTTPConnection http_version={http_version} state={self.state}>"

    def info(self) -> str:
        """Return a short human-readable description of the connection."""
        if self.connection is None:
            return "Not connected"
        if self.state == ConnectionState.PENDING:
            return "Connecting"
        return self.connection.info()

    @property
    def request_lock(self) -> AsyncLock:
        """Lock serialising request set-up, created on first access."""
        # We do this lazily, to make sure backend autodetection always
        # runs within an async context.
        try:
            return self._request_lock
        except AttributeError:
            self._request_lock = self.backend.create_lock()
            return self._request_lock

    async def arequest(
        self,
        method: bytes,
        url: URL,
        headers: Headers = None,
        stream: AsyncByteStream = None,
        ext: dict = None,
    ) -> Tuple[int, Headers, AsyncByteStream, dict]:
        """Send a request on this connection, connecting first if needed.

        The connect timeout is taken from ``ext["timeout"]`` when present.

        Raises:
            NewConnectionRequired: If the connection is in a state that
                cannot take another request.
        """
        assert url_to_origin(url) == self.origin
        if ext is None:
            ext = {}
        timeout = cast(TimeoutDict, ext.get("timeout", {}))

        async with self.request_lock:
            if self.state == ConnectionState.PENDING:
                # First use: open the socket unless one was supplied, then
                # pick the protocol handler.
                if not self.socket:
                    logger.trace(
                        "open_socket origin=%r timeout=%r",
                        self.origin, timeout,
                    )
                    self.socket = await self._open_socket(timeout)
                self._create_connection(self.socket)
            elif not (
                self.state in (ConnectionState.READY, ConnectionState.IDLE)
                or (self.state == ConnectionState.ACTIVE and self.is_http2)
            ):
                # Only READY/IDLE connections, or ACTIVE HTTP/2 ones, may
                # accept a further request.
                raise NewConnectionRequired()

        assert self.connection is not None
        logger.trace(
            "connection.arequest method=%r url=%r headers=%r",
            method, url, headers,
        )
        return await self.connection.arequest(
            method, url, headers, stream, ext)

    async def _open_socket(self, timeout: TimeoutDict = None) -> AsyncSocketStream:
        """Open the underlying stream, retrying connect failures.

        ``ConnectError``/``ConnectTimeout`` are retried up to
        ``self.retries`` times, sleeping for the next backoff delay between
        attempts. Once retries are exhausted, or on any other exception,
        the connection is marked as failed and the exception propagates.
        """
        scheme, hostname, port = self.origin
        if timeout is None:
            timeout = {}
        tls_context = self.ssl_context if scheme == b"https" else None

        remaining = self.retries
        backoff = exponential_backoff(factor=RETRIES_BACKOFF_FACTOR)

        while True:
            try:
                if self.uds is not None:
                    return await self.backend.open_uds_stream(
                        self.uds, hostname, tls_context, timeout)
                return await self.backend.open_tcp_stream(
                    hostname,
                    port,
                    tls_context,
                    timeout,
                    local_address=self.local_address,
                )
            except (ConnectError, ConnectTimeout):
                if remaining <= 0:
                    self.connect_failed = True
                    raise
                remaining -= 1
                await self.backend.sleep(next(backoff))
            except Exception:  # noqa: PIE786
                self.connect_failed = True
                raise

    def _create_connection(self, socket: AsyncSocketStream) -> None:
        """Instantiate the protocol handler matching the socket's version."""
        http_version = socket.get_http_version()
        logger.trace(
            "create_connection socket=%r http_version=%r",
            socket, http_version,
        )
        if http_version != "HTTP/2":
            from .http11 import AsyncHTTP11Connection

            self.is_http11 = True
            self.connection = AsyncHTTP11Connection(
                socket=socket, ssl_context=self.ssl_context)
            return

        from .http2 import AsyncHTTP2Connection

        self.is_http2 = True
        self.connection = AsyncHTTP2Connection(
            socket=socket, backend=self.backend, ssl_context=self.ssl_context)

    @property
    def state(self) -> ConnectionState:
        """CLOSED after a failed connect, PENDING before connecting,
        otherwise the state reported by the protocol handler."""
        if self.connect_failed:
            return ConnectionState.CLOSED
        if self.connection is None:
            return ConnectionState.PENDING
        return self.connection.get_state()

    def is_connection_dropped(self) -> bool:
        """Return False before connecting, else ask the protocol handler."""
        if self.connection is None:
            return False
        return self.connection.is_connection_dropped()

    def mark_as_ready(self) -> None:
        """Forward to the protocol handler, if there is one."""
        if self.connection is None:
            return
        self.connection.mark_as_ready()

    async def start_tls(self, hostname: bytes, timeout: TimeoutDict = None) -> None:
        """Upgrade the established connection to TLS; no-op if unconnected.

        The socket returned by the protocol handler replaces ``self.socket``.
        """
        if self.connection is None:
            return
        logger.trace("start_tls hostname=%r timeout=%r", hostname, timeout)
        self.socket = await self.connection.start_tls(hostname, timeout)
        logger.trace(
            "start_tls complete hostname=%r timeout=%r", hostname, timeout)

    async def aclose(self) -> None:
        """Close the protocol handler, if any, under the request lock."""
        async with self.request_lock:
            if self.connection is not None:
                await self.connection.aclose()
class AsyncHTTPConnection(AsyncHTTPTransport):
    """A single asynchronous connection to one origin.

    The TCP stream is opened on the first request, and the protocol
    handler (HTTP/1.1 or HTTP/2) is chosen from the HTTP version reported
    by that stream.
    """

    def __init__(
        self,
        origin: Tuple[bytes, bytes, int],
        http2: bool = False,
        ssl_context: SSLContext = None,
    ):
        """Store the configuration; no I/O is performed here.

        When ``http2`` is true, ``http/1.1`` and ``h2`` are registered as
        ALPN protocols on the TLS context (a context supplied by the
        caller is modified in place).
        """
        if ssl_context is None:
            ssl_context = SSLContext()

        self.origin = origin
        self.http2 = http2
        self.ssl_context = ssl_context

        if self.http2:
            self.ssl_context.set_alpn_protocols(["http/1.1", "h2"])

        # Protocol handler; stays None until the first request connects.
        self.connection: Union[
            None, AsyncHTTP11Connection, AsyncHTTP2Connection
        ] = None
        self.is_http11 = False
        self.is_http2 = False
        self.connect_failed = False
        self.expires_at: Optional[float] = None
        self.backend = AutoBackend()

    @property
    def request_lock(self) -> AsyncLock:
        """Lock serialising request set-up, created on first access."""
        # We do this lazily, to make sure backend autodetection always
        # runs within an async context.
        try:
            return self._request_lock
        except AttributeError:
            self._request_lock = self.backend.create_lock()
            return self._request_lock

    async def request(
        self,
        method: bytes,
        url: Tuple[bytes, bytes, int, bytes],
        headers: List[Tuple[bytes, bytes]] = None,
        stream: AsyncByteStream = None,
        timeout: Dict[str, Optional[float]] = None,
    ) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]], AsyncByteStream]:
        """Send a request on this connection, connecting first if needed.

        Raises:
            NewConnectionRequired: If the connection is in a state that
                cannot take another request.
        """
        assert url[:3] == self.origin

        async with self.request_lock:
            if self.state == ConnectionState.PENDING:
                try:
                    await self._connect(timeout)
                except BaseException:
                    # Deliberately catches everything (same as a bare
                    # ``except``): any interruption of the connect attempt
                    # marks the connection as failed before propagating.
                    self.connect_failed = True
                    raise
            elif not (
                self.state in (ConnectionState.READY, ConnectionState.IDLE)
                or (self.state == ConnectionState.ACTIVE and self.is_http2)
            ):
                # Only READY/IDLE connections, or ACTIVE HTTP/2 ones, may
                # accept a further request.
                raise NewConnectionRequired()

        assert self.connection is not None
        return await self.connection.request(
            method, url, headers, stream, timeout)

    async def _connect(
        self,
        timeout: Dict[str, Optional[float]] = None,
    ):
        """Open the TCP stream (TLS for ``https`` only) and create the
        protocol handler matching the HTTP version the stream reports."""
        scheme, hostname, port = self.origin
        if timeout is None:
            timeout = {}
        tls_context = self.ssl_context if scheme == b"https" else None

        socket = await self.backend.open_tcp_stream(
            hostname, port, tls_context, timeout)

        if socket.get_http_version() == "HTTP/2":
            self.is_http2 = True
            self.connection = AsyncHTTP2Connection(
                socket=socket, backend=self.backend)
        else:
            self.is_http11 = True
            self.connection = AsyncHTTP11Connection(socket=socket)

    @property
    def state(self) -> ConnectionState:
        """CLOSED after a failed connect, PENDING before connecting,
        otherwise the state of the underlying protocol handler."""
        if self.connect_failed:
            return ConnectionState.CLOSED
        if self.connection is None:
            return ConnectionState.PENDING
        return self.connection.state

    def is_connection_dropped(self) -> bool:
        """Return False before connecting, else ask the protocol handler."""
        if self.connection is None:
            return False
        return self.connection.is_connection_dropped()

    def mark_as_ready(self) -> None:
        """Forward to the protocol handler, if there is one."""
        if self.connection is None:
            return
        self.connection.mark_as_ready()

    async def start_tls(self, hostname: bytes,
                        timeout: Dict[str, Optional[float]] = None):
        """Upgrade the established connection to TLS; no-op if unconnected."""
        if self.connection is None:
            return
        await self.connection.start_tls(hostname, timeout)
class SyncHTTPConnection(SyncHTTPTransport):
    """A single synchronous connection to one origin.

    The socket (TCP or Unix domain) is opened on the first request, with
    optional retries, and the protocol handler (HTTP/1.1 or HTTP/2) is
    chosen from the HTTP version reported by that socket and from the
    ``http1``/``http2`` flags.
    """

    def __init__(
        self,
        origin: Origin,
        http1: bool = True,
        http2: bool = False,
        keepalive_expiry: float = None,
        uds: str = None,
        ssl_context: SSLContext = None,
        socket: SyncSocketStream = None,
        local_address: str = None,
        retries: int = 0,
        backend: SyncBackend = None,
    ):
        """Store the configuration; no I/O is performed here.

        The ALPN protocol list on the TLS context is set from the
        ``http1``/``http2`` flags (a context supplied by the caller is
        modified in place).
        """
        self.origin = origin
        self._http1_enabled = http1
        self._http2_enabled = http2
        self._keepalive_expiry = keepalive_expiry
        self._uds = uds
        self._ssl_context = SSLContext() if ssl_context is None else ssl_context
        self.socket = socket
        self._local_address = local_address
        self._retries = retries

        alpn_protocols: List[str] = []
        if http1:
            alpn_protocols.append("http/1.1")
        if http2:
            alpn_protocols.append("h2")

        self._ssl_context.set_alpn_protocols(alpn_protocols)

        # Protocol handler; stays None until the first request connects.
        self.connection: Optional[SyncBaseHTTPConnection] = None
        self._is_http11 = False
        self._is_http2 = False
        self._connect_failed = False
        self._expires_at: Optional[float] = None
        self._backend = SyncBackend() if backend is None else backend

    def __repr__(self) -> str:
        return f"<SyncHTTPConnection [{self.info()}]>"

    def info(self) -> str:
        """Return a short human-readable description of the connection."""
        if self.connection is None:
            return "Connection failed" if self._connect_failed else "Connecting"
        return self.connection.info()

    def should_close(self) -> bool:
        """
        Return `True` if the connection is in a state where it should be closed.
        This occurs when any of the following occur:

        * There are no active requests on an HTTP/1.1 connection, and the
          underlying socket is readable. The only valid state the socket can
          be readable in if this occurs is when the b"" EOF marker is about
          to be returned, indicating a server disconnect.
        * There are no active requests being made and the keepalive timeout
          has passed.
        """
        if self.connection is None:
            return False
        return self.connection.should_close()

    def is_idle(self) -> bool:
        """
        Return `True` if the connection is currently idle.
        """
        if self.connection is None:
            return False
        return self.connection.is_idle()

    def is_closed(self) -> bool:
        """Return `True` if connecting failed or the connection is closed."""
        if self.connection is None:
            return self._connect_failed
        return self.connection.is_closed()

    def is_available(self) -> bool:
        """
        Return `True` if the connection is currently able to accept an
        outgoing request.
        This occurs when any of the following occur:

        * The connection has not yet been opened, and HTTP/2 support is
          enabled. We don't *know* at this point if we'll end up on an
          HTTP/2 connection or not, but we *might* do, so we indicate
          availability.
        * The connection has been opened, and is currently idle.
        * The connection is open, and is an HTTP/2 connection. The connection
          must also not currently be exceeding the maximum number of allowable
          concurrent streams and must not have exhausted the maximum total
          number of stream IDs.
        """
        if self.connection is None:
            # `is_closed` must be *called*: the bound method object itself
            # is always truthy, which made this branch always return False.
            return self._http2_enabled and not self.is_closed()
        return self.connection.is_available()

    @property
    def request_lock(self) -> SyncLock:
        """Lock serialising request set-up, created on first access."""
        # We do this lazily, to make sure backend autodetection always
        # runs within an async context.
        if not hasattr(self, "_request_lock"):
            self._request_lock = self._backend.create_lock()
        return self._request_lock

    def handle_request(
        self,
        method: bytes,
        url: URL,
        headers: Headers,
        stream: SyncByteStream,
        extensions: dict,
    ) -> Tuple[int, Headers, SyncByteStream, dict]:
        """Send a request on this connection, connecting first if needed.

        The connect timeout is taken from ``extensions["timeout"]`` when
        present.

        Raises:
            NewConnectionRequired: If a previous connect attempt failed or
                the established connection is not available.
        """
        assert url_to_origin(url) == self.origin
        timeout = cast(TimeoutDict, extensions.get("timeout", {}))

        with self.request_lock:
            if self.connection is None:
                if self._connect_failed:
                    raise NewConnectionRequired()
                # First use: open the socket unless one was supplied, then
                # pick the protocol handler.
                if not self.socket:
                    logger.trace(
                        "open_socket origin=%r timeout=%r", self.origin, timeout
                    )
                    self.socket = self._open_socket(timeout)
                self._create_connection(self.socket)
            elif not self.connection.is_available():
                raise NewConnectionRequired()

        assert self.connection is not None
        logger.trace(
            "connection.handle_request method=%r url=%r headers=%r",
            method,
            url,
            headers,
        )
        return self.connection.handle_request(
            method, url, headers, stream, extensions
        )

    def _open_socket(self, timeout: TimeoutDict = None) -> SyncSocketStream:
        """Open the underlying stream, retrying connect failures.

        ``ConnectError``/``ConnectTimeout`` are retried up to
        ``self._retries`` times, sleeping for the next backoff delay
        between attempts. Once retries are exhausted, or on any other
        exception, the connection is marked as failed and the exception
        propagates.
        """
        scheme, hostname, port = self.origin
        timeout = {} if timeout is None else timeout
        # TLS is only used for https origins.
        ssl_context = self._ssl_context if scheme == b"https" else None

        retries_left = self._retries
        delays = exponential_backoff(factor=RETRIES_BACKOFF_FACTOR)

        while True:
            try:
                if self._uds is None:
                    return self._backend.open_tcp_stream(
                        hostname,
                        port,
                        ssl_context,
                        timeout,
                        local_address=self._local_address,
                    )
                else:
                    return self._backend.open_uds_stream(
                        self._uds, hostname, ssl_context, timeout
                    )
            except (ConnectError, ConnectTimeout):
                if retries_left <= 0:
                    self._connect_failed = True
                    raise
                retries_left -= 1
                delay = next(delays)
                self._backend.sleep(delay)
            except Exception:  # noqa: PIE786
                self._connect_failed = True
                raise

    def _create_connection(self, socket: SyncSocketStream) -> None:
        """Instantiate the protocol handler for ``socket``.

        HTTP/2 is used when the socket reports ``HTTP/2`` or when HTTP/2 is
        the only enabled protocol; otherwise HTTP/1.1 is used.
        """
        http_version = socket.get_http_version()
        logger.trace(
            "create_connection socket=%r http_version=%r", socket, http_version
        )
        if http_version == "HTTP/2" or (
            self._http2_enabled and not self._http1_enabled
        ):
            from .http2 import SyncHTTP2Connection

            self._is_http2 = True
            self.connection = SyncHTTP2Connection(
                socket=socket,
                keepalive_expiry=self._keepalive_expiry,
                backend=self._backend,
            )
        else:
            self._is_http11 = True
            self.connection = SyncHTTP11Connection(
                socket=socket, keepalive_expiry=self._keepalive_expiry
            )

    def start_tls(
        self, hostname: bytes, ssl_context: SSLContext, timeout: TimeoutDict = None
    ) -> None:
        """Upgrade the established connection to TLS; no-op if unconnected.

        The socket returned by the protocol handler replaces ``self.socket``.
        """
        if self.connection is not None:
            logger.trace("start_tls hostname=%r timeout=%r", hostname, timeout)
            self.socket = self.connection.start_tls(
                hostname, ssl_context, timeout
            )
            logger.trace(
                "start_tls complete hostname=%r timeout=%r", hostname, timeout
            )

    def close(self) -> None:
        """Close the protocol handler, if any, under the request lock."""
        with self.request_lock:
            if self.connection is not None:
                self.connection.close()