async def _serve_client(stream):
    """Hand an accepted connection to ``entry_point``, upgrading it to TLS first
    when ``addr.use_ssl`` is set.

    ``addr`` and ``entry_point`` come from the enclosing scope (not visible here).
    """
    if not addr.use_ssl:
        await entry_point(stream)
        return
    # NOTE(review): this is a server-side wrap, but the context comes from
    # ssl.create_default_context() (default purpose SERVER_AUTH, i.e. client
    # defaults) and no certificate chain is loaded here; server_hostname is also
    # only meaningful for clients. Confirm how the server certificate is supplied.
    server_ctx = ssl.create_default_context()
    tls_stream = trio.SSLStream(
        stream, server_ctx, server_hostname=addr.hostname, server_side=True
    )
    await entry_point(tls_stream)
async def open_tcp_stream(
    self,
    hostname: bytes,
    port: int,
    ssl_context: Optional[SSLContext],
    timeout: TimeoutDict,
    *,
    local_address: Optional[str],
) -> AsyncSocketStream:
    """Open a TCP connection, optionally upgrading it to TLS.

    Args:
        hostname: Host name as ASCII bytes; decoded for TLS SNI/verification.
        port: TCP port to connect to.
        ssl_context: When not None, the connection is wrapped in
            ``trio.SSLStream`` and the handshake is completed before returning.
        timeout: Mapping whose ``"connect"`` entry bounds connect + handshake
            together (None means no limit).
        local_address: Local address to bind; forwarded to trio only when set.

    Returns:
        A ``SocketStream`` wrapping the (possibly TLS) trio stream.

    Raises:
        Whatever ``map_exceptions`` produces from ``exc_map``: trio.TooSlowError
        is mapped to ConnectTimeout and trio.BrokenResourceError to ConnectError.
    """
    connect_timeout = none_as_inf(timeout.get("connect"))
    # Trio will support local_address from 0.16.1 onwards.
    # We only include the keyword argument if a local_address
    # argument has been passed, so older trio versions keep working.
    kwargs: dict = {} if local_address is None else {"local_address": local_address}

    exc_map = {
        trio.TooSlowError: ConnectTimeout,
        trio.BrokenResourceError: ConnectError,
    }

    with map_exceptions(exc_map):
        with trio.fail_after(connect_timeout):
            stream: trio.abc.Stream = await trio.open_tcp_stream(
                hostname, port, **kwargs
            )

            if ssl_context is not None:
                stream = trio.SSLStream(
                    stream, ssl_context, server_hostname=hostname.decode("ascii")
                )
                try:
                    await stream.do_handshake()
                except BaseException:
                    # Failed, timed out or cancelled handshake: close the
                    # underlying socket instead of leaking it, then re-raise.
                    await trio.aclose_forcefully(stream)
                    raise

            return SocketStream(stream=stream)
async def open_stream_to_backend(hostname, port, use_ssl):
    """Open a TCP stream to ``hostname:port``, wrapped in TLS when *use_ssl*.

    No handshake is done here; ``trio.SSLStream`` performs it automatically on
    the first send/receive.
    """
    stream = await trio.open_tcp_stream(hostname, port)
    if use_ssl:
        # create_default_context() already loads the system default CA
        # certificates for SERVER_AUTH, so a separate load_default_certs()
        # call would only load the same certificates again.
        ssl_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
        stream = trio.SSLStream(stream, ssl_context, server_hostname=hostname)
    return stream
async def make_socket(self, af, socktype, proto=0, source=None,
                      destination=None, timeout=None, ssl_context=None,
                      server_hostname=None):
    """Create a trio socket and wrap it in a datagram or stream socket object.

    For SOCK_STREAM the socket is connected to *destination* (bounded by
    ``_maybe_timeout(timeout)``) and, when *ssl_context* is given, wrapped in
    ``trio.SSLStream``. For SOCK_DGRAM it is returned as a ``DatagramSocket``.

    Raises:
        NotImplementedError: for any other *socktype*.
    """
    s = trio.socket.socket(af, socktype, proto)
    try:
        if source:
            await s.bind(_lltuple(source, af))
        if socktype == socket.SOCK_STREAM:
            with _maybe_timeout(timeout):
                await s.connect(_lltuple(destination, af))
    except BaseException:  # pragma: no cover
        # BaseException so that cancellation also closes the socket.
        s.close()
        raise
    if socktype == socket.SOCK_DGRAM:
        return DatagramSocket(s)
    elif socktype == socket.SOCK_STREAM:
        # Ownership of `s` passes to the stream from here on.
        stream = trio.SocketStream(s)
        tls = False
        if ssl_context:
            tls = True
            try:
                stream = trio.SSLStream(stream, ssl_context,
                                        server_hostname=server_hostname)
            except Exception:  # pragma: no cover
                await stream.aclose()
                raise
        return StreamSocket(af, stream, tls)
    # Unsupported socket type: don't leak the socket we just created.
    s.close()
    raise NotImplementedError('unsupported socket ' +
                              f'type {socktype}')  # pragma: no cover
async def open_uds_stream(
    self,
    path: str,
    hostname: bytes,
    ssl_context: Optional[SSLContext],
    timeout: TimeoutDict,
) -> AsyncSocketStream:
    """Open a Unix-domain-socket connection, optionally upgrading it to TLS.

    Args:
        path: Filesystem path of the Unix socket.
        hostname: Host name as ASCII bytes, used for TLS SNI/verification.
        ssl_context: When not None, wrap in ``trio.SSLStream`` and complete the
            handshake before returning.
        timeout: Mapping whose ``"connect"`` entry bounds connect + handshake
            (None means no limit).

    Raises:
        Whatever ``map_exceptions`` produces from ``exc_map``: trio.TooSlowError
        is mapped to ConnectTimeout and trio.BrokenResourceError to ConnectError.
    """
    connect_timeout = none_as_inf(timeout.get("connect"))
    exc_map = {
        trio.TooSlowError: ConnectTimeout,
        trio.BrokenResourceError: ConnectError,
    }

    with map_exceptions(exc_map):
        with trio.fail_after(connect_timeout):
            stream: trio.abc.Stream = await trio.open_unix_socket(path)

            if ssl_context is not None:
                stream = trio.SSLStream(
                    stream, ssl_context, server_hostname=hostname.decode("ascii")
                )
                try:
                    await stream.do_handshake()
                except BaseException:
                    # Don't leak the socket if the handshake fails, times out
                    # or is cancelled.
                    await trio.aclose_forcefully(stream)
                    raise

            return SocketStream(stream=stream)
async def open_tcp_stream(
    self,
    hostname: bytes,
    port: int,
    ssl_context: Optional[SSLContext],
    timeout: Dict[str, Optional[float]],
) -> AsyncSocketStream:
    """Open a TCP connection, optionally upgrading it to TLS.

    Args:
        hostname: Host to connect to; also passed to ``trio.SSLStream`` as
            ``server_hostname``.
        port: TCP port.
        ssl_context: When not None, wrap in ``trio.SSLStream`` and complete the
            handshake before returning.
        timeout: Mapping whose ``"connect"`` entry bounds connect + handshake
            (None means no limit).

    Raises:
        Whatever ``map_exceptions`` produces from ``exc_map``: trio.TooSlowError
        is mapped to ConnectTimeout and trio.BrokenResourceError to ConnectError.
    """
    connect_timeout = none_as_inf(timeout.get("connect"))
    exc_map = {
        trio.TooSlowError: ConnectTimeout,
        trio.BrokenResourceError: ConnectError,
    }

    with map_exceptions(exc_map):
        with trio.fail_after(connect_timeout):
            # Annotated as the abstract Stream: it may be rebound to an SSLStream.
            stream: trio.abc.Stream = await trio.open_tcp_stream(hostname, port)

            if ssl_context is not None:
                stream = trio.SSLStream(stream, ssl_context, server_hostname=hostname)
                try:
                    await stream.do_handshake()
                except BaseException:
                    # Don't leak the socket if the handshake fails, times out
                    # or is cancelled.
                    await trio.aclose_forcefully(stream)
                    raise

            return SocketStream(stream=stream)
async def open_tcp_stream(
    self,
    hostname: bytes,
    port: int,
    ssl_context: Optional[SSLContext],
    timeout: TimeoutDict,
    local_addr: Optional[bytes],
) -> AsyncSocketStream:
    """Open a TCP connection, optionally upgrading it to TLS.

    Args:
        hostname: Host name as ASCII bytes; decoded for TLS SNI/verification.
        port: TCP port.
        ssl_context: When not None, wrap in ``trio.SSLStream`` and complete the
            handshake before returning.
        timeout: Mapping whose ``"connect"`` entry bounds connect + handshake
            (None means no limit).
        local_addr: Must be falsy; binding a local address is not supported here.

    Raises:
        NotImplementedError: if *local_addr* is given.
        Whatever ``map_exceptions`` produces from ``exc_map``: trio.TooSlowError
        is mapped to ConnectTimeout and trio.BrokenResourceError to ConnectError.
    """
    # trio doesn't currently support specifying the local address; it will
    # as of 0.16.1.
    if local_addr:
        raise NotImplementedError()

    connect_timeout = none_as_inf(timeout.get("connect"))
    exc_map = {
        trio.TooSlowError: ConnectTimeout,
        trio.BrokenResourceError: ConnectError,
    }

    with map_exceptions(exc_map):
        with trio.fail_after(connect_timeout):
            stream: trio.abc.Stream = await trio.open_tcp_stream(hostname, port)

            if ssl_context is not None:
                stream = trio.SSLStream(
                    stream, ssl_context, server_hostname=hostname.decode("ascii")
                )
                try:
                    await stream.do_handshake()
                except BaseException:
                    # Don't leak the socket if the handshake fails, times out
                    # or is cancelled.
                    await trio.aclose_forcefully(stream)
                    raise

            return SocketStream(stream=stream)
async def start_tls(
    self,
    ssl_context: ssl.SSLContext,
    server_hostname: str = None,
    timeout: float = None,
) -> AsyncNetworkStream:
    """Upgrade this connection to TLS and return a new ``TrioStream`` over it.

    Args:
        ssl_context: Client-side SSL context for the handshake.
        server_hostname: Name used for SNI and certificate hostname checks.
        timeout: Seconds allowed for the handshake; None means no limit.

    Raises:
        Whatever ``map_exceptions`` produces from ``exc_map``: trio.TooSlowError
        is mapped to ConnectTimeout and trio.BrokenResourceError to ConnectError.
        On any handshake failure this stream is closed before re-raising.
    """
    timeout_or_inf = float("inf") if timeout is None else timeout
    exc_map = {
        trio.TooSlowError: ConnectTimeout,
        trio.BrokenResourceError: ConnectError,
    }
    ssl_stream = trio.SSLStream(
        self._stream,
        ssl_context=ssl_context,
        server_hostname=server_hostname,
        https_compatible=True,
        server_side=False,
    )
    with map_exceptions(exc_map):
        try:
            with trio.fail_after(timeout_or_inf):
                await ssl_stream.do_handshake()
        except Exception:  # pragma: nocover
            await self.aclose()
            # Bare `raise` re-raises the active exception with its original
            # traceback untouched.
            raise
    return TrioStream(ssl_stream)
async def _http_send(target):
    """Send a bare ``GET <target>`` to the backend and return part of the reply.

    Returns the whole decoded response when its body starts with
    ``<!DOCTYPE``, otherwise only the decoded header section.
    ``backend_addr`` comes from the enclosing scope.
    """
    stream = await trio.open_tcp_stream(backend_addr.hostname, backend_addr.port)
    try:
        if backend_addr.use_ssl:
            # create_default_context() already loads the default CA certs.
            ssl_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
            stream = trio.SSLStream(
                stream, ssl_context, server_hostname=backend_addr.hostname
            )
        if isinstance(target, str):
            target = target.encode("utf8")
        req = b"GET %s HTTP/1.1\r\nHost: %s \r\n\r\n" % (
            target,
            backend_addr.hostname.encode("idna"),
        )
        await stream.send_all(req)
        # NOTE(review): a single receive_some() may return only part of the
        # response; callers appear to rely on the first chunk only.
        data = await stream.receive_some()
    finally:
        # Close on every path, including send/receive failures.
        await stream.aclose()

    # partition() tolerates a reply without a header/body separator
    # (split()[1] raised IndexError in that case).
    head, _, body = data.partition(b"\r\n\r\n")
    if body.startswith(b"<!DOCTYPE"):
        return data.decode("utf8")
    return head.decode("utf8")
async def _establish_ssl_stream(self, dest_replica, lot):
    """
    A task that tries to connect to dest_replica in a loop until told to quit
    (self.exit_flag set, on BFT client exit). There are 2 states:
    1) Connected to dest_replica - in that case, park in the lot.
    2) Disconnected from dest_replica - in that case, try to connect to it.
    On success, insert the new SSL stream into self.ssl_streams and park until
    un-parked. On failure (OSError / trio.BrokenResourceError), sleep for
    0.1 sec and retry.
    An SSL stream might be removed from self.ssl_streams while sending or
    receiving data, after finding out that the connection is closed or broken.
    """
    if self.exit_flag:
        return
    server_cert_path = self._get_cert_path(dest_replica.id, is_client=False)
    client_cert_path = self._get_cert_path(self.client_id, is_client=True)
    client_pk_path = self._get_private_key_path(self.client_id, is_client=True)
    # Create an SSL context - PROTOCOL_TLS_CLIENT enables CERT_REQUIRED and check_hostname = True
    ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    # Verify server certificate using this trusted path
    ssl_context.load_verify_locations(cafile=server_cert_path)
    # Load my private key and certificate (presented for mutual TLS)
    ssl_context.load_cert_chain(client_cert_path, client_pk_path)
    # Server hostname to be verified must be taken from create_tls_certs.sh
    server_hostname = self.CERT_DOMAIN_FORMAT % dest_replica.id
    dest_addr = (dest_replica.ip, dest_replica.port)
    ssl_stream = tcp_stream = None

    # Keep (re)connecting until self.exit_flag is set.
    # NOTE(review): an older comment here mentioned an "event" whose initial
    # state should be True; no such event is visible in this function.
    while not self.exit_flag:
        try:
            # Open TCP stream and connect to server
            tcp_stream = await trio.open_tcp_stream(
                str(dest_replica.ip), int(dest_replica.port))
            # Wrap this stream with SSL stream, pass server_hostname to be verified
            ssl_stream = trio.SSLStream(tcp_stream, ssl_context, server_hostname=server_hostname,
                                        https_compatible=False)
            # Wait for handshake to finish (we want to be on the safe side - after this we are sure
            # connection is open)
            await ssl_stream.do_handshake()
            # Success! Keep stream in dictionary (there is no `break`; the loop
            # re-checks exit_flag after the task is un-parked)
            # NOTE(review): if exit_flag becomes True during the handshake, the
            # established ssl_stream is neither stored nor closed here - confirm.
            if not self.exit_flag:
                self.ssl_streams[dest_addr] = ssl_stream
                # Ownership moved to self.ssl_streams; don't close these locally.
                tcp_stream = ssl_stream = None
                # park the task till it is woken by unpark()
                await lot.park()
                if dest_addr in self.ssl_streams:
                    # delete and close the stream
                    await self._close_ssl_stream(dest_addr)
        except (OSError, trio.BrokenResourceError):
            await trio.sleep(0.1)
            # Close whichever layer was opened by the failed attempt.
            # NOTE(review): the variables are not reset to None afterwards, so a
            # later failed attempt may aclose() an already-closed stream again.
            if ssl_stream:
                await ssl_stream.aclose()
            elif tcp_stream:
                await tcp_stream.aclose()
    if dest_addr in self.ssl_streams:
        await self._close_ssl_stream(dest_addr)
async def starttls_client(self, server_hostname: str) -> None:
    """Replace ``self.stream`` with a client-side TLS wrapper around it.

    The handshake is not performed here; ``trio.SSLStream`` runs it on first I/O.
    """
    tls_ctx = self.ssl_context_or_default_client()
    wrapped = trio.SSLStream(
        self.stream,
        ssl_context=tls_ctx,
        server_hostname=server_hostname,
    )
    self.stream = wrapped
async def starttls_server(self) -> None:
    """Replace ``self.stream`` with a server-side TLS wrapper around it.

    The handshake is not performed here; ``trio.SSLStream`` runs it on first I/O.
    """
    server_ctx = self.ssl_context_or_default_server()
    wrapped = trio.SSLStream(self.stream, ssl_context=server_ctx, server_side=True)
    self.stream = wrapped
async def start_tls(self, server_hostname, ssl_context):
    """Wrap the underlying stream in TLS, finish the handshake, and return a
    new ``TrioSocket`` over the encrypted stream."""
    tls_stream = trio.SSLStream(
        self._stream,
        ssl_context,
        server_hostname=server_hostname,
        https_compatible=True,
    )
    await tls_stream.do_handshake()
    return TrioSocket(tls_stream)
async def ssl_wrapper():
    """Connect via ``socket_connect()``, upgrade to TLS and complete the handshake.

    ``socket_connect``, ``ssl_context`` and ``server_hostname`` come from the
    enclosing scope.
    """
    raw_stream = await socket_connect()
    tls_stream = trio.SSLStream(raw_stream, ssl_context, server_hostname=server_hostname)
    await tls_stream.do_handshake()
    return tls_stream
async def tls_connect(server_hostname, tcp_port):
    """Open a TCP connection and return it as a handshaken ``trio.SSLStream``.

    The SSL context comes from ``tls_context()`` (defined elsewhere).
    If context creation or the handshake fails, the TCP stream is closed and
    the exception re-raised.
    """
    tcp_stream = await trio.open_tcp_stream(server_hostname, tcp_port)
    try:
        ssl_context = tls_context()
        tls_stream = trio.SSLStream(transport_stream=tcp_stream,
                                    ssl_context=ssl_context,
                                    server_hostname=server_hostname,
                                    https_compatible=True)
        await tls_stream.do_handshake()
    except BaseException:
        # Don't leak the socket on failure or cancellation.
        await trio.aclose_forcefully(tcp_stream)
        raise
    return tls_stream
async def _serve_client(stream):
    """Serve one client via ``backend.handle_client``, wrapping it in
    server-side TLS first when ``ssl_context`` (enclosing scope) is set.

    Unexpected exceptions are logged rather than propagated; the stream is
    always closed afterwards, including on cancellation.
    """
    if ssl_context:
        stream = trio.SSLStream(stream, ssl_context, server_side=True)

    try:
        await backend.handle_client(stream)
    except Exception:
        # If we are here, something unexpected happened...
        logger.exception("Unexpected crash")
    finally:
        # `finally` (not a trailing statement) so BaseException such as
        # trio.Cancelled also closes the connection.
        await stream.aclose()
async def start_tls(self, hostname: str, ssl_context: ssl.SSLContext, timeout: Timeout) -> "SocketStream":
    """Upgrade ``self.stream`` to TLS within ``timeout.connect_timeout``.

    Raises:
        ConnectTimeout: if the handshake does not complete in time.
    """
    handshake_deadline = none_as_inf(timeout.connect_timeout)
    tls_stream = trio.SSLStream(self.stream, ssl_context=ssl_context, server_hostname=hostname)
    with trio.move_on_after(handshake_deadline) as scope:
        await tls_stream.do_handshake()
    if scope.cancelled_caught:
        raise ConnectTimeout()
    return SocketStream(tls_stream)
def _upgrade_stream_to_ssl(raw_stream, hostname):
    """Wrap *raw_stream* in a client-side ``trio.SSLStream`` verifying *hostname*.

    Trusts the system default CA certificates plus, when the ``SSL_CAFILE``
    environment variable is set, the certificates in that file.
    """
    # The ssl context should be generated once and stored into the config
    # however this is tricky (should ssl configuration be stored per device ?)
    #
    # create_default_context() already loads the system default CA certs
    # (no cafile/capath/cadata given), so no separate load_default_certs().
    ssl_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
    cafile = os.environ.get("SSL_CAFILE")
    if cafile:
        ssl_context.load_verify_locations(cafile)
    return trio.SSLStream(raw_stream, ssl_context, server_hostname=hostname)
async def tls_connect(server_hostname, tcp_port):
    """Open a TCP connection and return a handshaken ``trio.SSLStream``.

    WARNING: certificate and hostname verification are disabled, so the peer
    is NOT authenticated.
    """
    tcp_stream = await trio.open_tcp_stream(server_hostname, tcp_port)
    ssl_context = ssl.create_default_context()
    # SECURITY NOTE(review): any certificate is accepted. Kept because callers
    # may rely on it (e.g. self-signed servers); confirm this is intended.
    # check_hostname must be disabled before verify_mode can be CERT_NONE.
    ssl_context.check_hostname = False
    ssl_context.verify_mode = ssl.CERT_NONE
    ssl_stream = trio.SSLStream(transport_stream=tcp_stream,
                                ssl_context=ssl_context,
                                server_hostname=server_hostname,
                                https_compatible=True)
    try:
        await ssl_stream.do_handshake()
    except BaseException:
        # Don't leak the socket on handshake failure or cancellation.
        await trio.aclose_forcefully(tcp_stream)
        raise
    # (The former trailing `return tcp_stream` was unreachable and is removed.)
    return ssl_stream
async def start_tls(self, hostname: str, ssl_context: ssl.SSLContext, timeout: Timeout) -> "SocketStream":
    """Upgrade ``self.stream`` to TLS within ``timeout.connect_timeout``.

    Raises:
        ConnectTimeout: if the handshake does not complete in time.
    """
    handshake_deadline = _or_inf(timeout.connect_timeout)
    tls_stream = trio.SSLStream(self.stream, ssl_context=ssl_context, server_hostname=hostname)
    with trio.move_on_after(handshake_deadline):
        await tls_stream.do_handshake()
        return SocketStream(tls_stream, self.timeout)
    # Only reached when move_on_after cancelled the handshake.
    raise ConnectTimeout()
async def tls_connect(server_hostname: AnyStr, tcp_port: int) -> trio.abc.Stream:
    """Open a verified TLS connection to ``server_hostname:tcp_port``.

    The certificate must be valid for *server_hostname* (CERT_REQUIRED with
    hostname checking). On handshake failure or cancellation the TCP stream
    is closed and the exception re-raised.
    """
    tcp_stream = await trio.open_tcp_stream(server_hostname, tcp_port)
    ssl_context = ssl.create_default_context()
    ssl_context.check_hostname = True
    ssl_context.verify_mode = ssl.CERT_REQUIRED
    ssl_stream = trio.SSLStream(transport_stream=tcp_stream,
                                ssl_context=ssl_context,
                                server_hostname=server_hostname,
                                https_compatible=True)
    try:
        await ssl_stream.do_handshake()
    except BaseException:
        # Don't leak the socket when the handshake fails.
        await trio.aclose_forcefully(tcp_stream)
        raise
    return ssl_stream
async def _serve_client(stream: trio.abc.Stream) -> None:
    """Serve one client via ``backend.handle_client``, wrapping it in
    server-side TLS first when ``ssl_context`` (enclosing scope) is set.

    ConnectionError propagates to the caller; other exceptions are logged.
    The stream is closed on every exit path.
    """
    if ssl_context:
        stream = trio.SSLStream(stream, ssl_context, server_side=True)

    try:
        await backend.handle_client(stream)
    except ConnectionError:
        # Should be handled by the reconnection logic (see `_run_and_retry_back`)
        raise
    except Exception:
        # If we are here, something unexpected happened...
        logger.exception("Unexpected crash")
    finally:
        # Previously skipped when ConnectionError (or cancellation) propagated,
        # leaking the connection.
        await stream.aclose()
async def open_ssl_over_tcp_stream(
    host,
    port,
    *,
    https_compatible=False,
    ssl_context=None,
    happy_eyeballs_delay=DEFAULT_DELAY,
):
    """Make a TLS-encrypted connection to the given host and port over TCP.

    This is a convenience wrapper that calls :func:`open_tcp_stream` and
    wraps the result in an :class:`~trio.SSLStream`.

    This function does not perform the TLS handshake; you can do it
    manually by calling :meth:`~trio.SSLStream.do_handshake`, or else
    it will be performed automatically the first time you send or receive
    data.

    Args:
      host (bytes or str): The host to connect to. We require the server
          to have a TLS certificate valid for this hostname.
      port (int): The port to connect to.
      https_compatible (bool): Set this to True if you're connecting to a web
          server. See :class:`~trio.SSLStream` for details. Default:
          False.
      ssl_context (:class:`~ssl.SSLContext` or None): The SSL context to
          use. If None (the default), :func:`ssl.create_default_context`
          will be called to create a context.
      happy_eyeballs_delay (float): See :func:`open_tcp_stream`.

    Returns:
      trio.SSLStream: the encrypted connection to the server.

    """
    transport = await trio.open_tcp_stream(
        host,
        port,
        happy_eyeballs_delay=happy_eyeballs_delay,
    )
    context = ssl.create_default_context() if ssl_context is None else ssl_context
    return trio.SSLStream(
        transport,
        context,
        server_hostname=host,
        https_compatible=https_compatible,
    )
async def start_tls(self, hostname: bytes, ssl_context: SSLContext, timeout: TimeoutDict) -> "SocketStream":
    """Upgrade ``self.stream`` to TLS, bounded by ``timeout["connect"]``.

    *hostname* is ASCII bytes, decoded for SNI/verification. Errors are
    translated through ``map_exceptions`` using the mapping below.
    """
    handshake_deadline = none_as_inf(timeout.get("connect"))
    error_translation = {
        trio.TooSlowError: ConnectTimeout,
        trio.BrokenResourceError: ConnectError,
    }
    tls_stream = trio.SSLStream(
        self.stream,
        ssl_context=ssl_context,
        server_hostname=hostname.decode("ascii"),
    )

    with map_exceptions(error_translation), trio.fail_after(handshake_deadline):
        await tls_stream.do_handshake()
        return SocketStream(tls_stream)
async def open_tcp_stream(
    self,
    hostname: str,
    port: int,
    ssl_context: typing.Optional[ssl.SSLContext],
    timeout: Timeout,
) -> SocketStream:
    """Open a TCP connection, optionally upgrading it to TLS.

    Connect and handshake together are bounded by ``timeout.connect_timeout``.

    Raises:
        ConnectTimeout: if that deadline expires.
    """
    connect_timeout = none_as_inf(timeout.connect_timeout)

    with trio.move_on_after(connect_timeout):
        # Abstract Stream annotation: may be rebound to an SSLStream below.
        stream: trio.abc.Stream = await trio.open_tcp_stream(hostname, port)

        if ssl_context is not None:
            stream = trio.SSLStream(stream, ssl_context, server_hostname=hostname)
            try:
                await stream.do_handshake()
            except BaseException:
                # Includes the Cancelled raised by move_on_after: close the
                # socket instead of leaking it, then let it propagate.
                await trio.aclose_forcefully(stream)
                raise

        return SocketStream(stream=stream)

    raise ConnectTimeout()
async def start_tls(self, hostname: str, ssl_context: ssl.SSLContext, timeout: TimeoutConfig) -> "SocketStream":
    """Upgrade ``self.stream`` to TLS within ``timeout.connect_timeout``.

    Raises:
        ConnectTimeout: if the handshake does not complete in time.
    """
    # TLS must not start while unsent plaintext is still buffered.
    assert self.write_buffer == b""

    handshake_deadline = _or_inf(timeout.connect_timeout)
    tls_stream = trio.SSLStream(self.stream, ssl_context=ssl_context, server_hostname=hostname)

    with trio.move_on_after(handshake_deadline) as scope:
        await tls_stream.do_handshake()

    if not scope.cancelled_caught:
        return SocketStream(tls_stream, self.timeout)
    raise ConnectTimeout()
async def execute(self, connections, task_status=trio.TASK_STATUS_IGNORED):
    # Connect to self.ip:443 over TLS (SNI/verification name self.host), set up
    # an HTTP/2 client connection, run send_task/recv_task until they finish or
    # are cancelled, then record duration/reason and close per-request streams.
    # self.connection is the CancelScope other code can use to cancel this run.
    with trio.CancelScope() as self.connection:
        self.duration = None
        t = time.monotonic()
        print(f"[{self.name}] Trying {self.ip}", end="\033[K\r", flush=True)
        # The TCP stream is the async-with resource; `sock` is then rebound to
        # the TLS wrapper, but the TCP stream is still closed on exit.
        async with await trio.open_tcp_stream(self.ip, 443) as sock:
            sock = trio.SSLStream(sock, server_hostname=self.host, https_compatible=True, ssl_context=self.ssl)
            await sock.do_handshake()
            # getpeercert() returns an empty dict when the certificate was not
            # validated, so 'subject' is then missing and 'not validated' is shown.
            cert = sock.getpeercert().get('subject')
            # The comprehension's `t` is scoped to the comprehension and does
            # not overwrite the start time `t` above.
            cert = cert and {k: v for t in cert for k, v in t }.get('commonName') or 'not validated'
            self.sock = sock
            self.conn = h2.connection.H2Connection(
                config=h2.config.H2Configuration(client_side=True, header_encoding="UTF-8"))
            self.conn.initiate_connection()
            print(f"[{self.name}] {self.ip} connected, cert {cert}")
            self.reason = None
            connections.add(self)
            try:
                async with trio.open_nursery() as nursery:
                    # Report readiness to whoever started this task via nursery.start().
                    task_status.started()
                    nursery.start_soon(self.send_task)
                    nursery.start_soon(self.recv_task)
            finally:
                # Shielded, time-boxed (1 s) cleanup so it still runs when the
                # surrounding scope has been cancelled.
                with trio.move_on_after(1) as cleanup:
                    cleanup.shield = True
                    connections.remove(self)
                    self.exited.set()
                    self.duration = time.monotonic() - t
                    if self.reason is None:
                        self.reason = "we canceled" if self.connection.cancel_called else "disconnected"
                    requests = f"answered {self.successes}/{self.attempted}" if self.attempted else "no requests done"
                    print(
                        f"[{self.name}] {self.ip} {self.reason} after {self.duration:.2f} s, {requests}"
                    )
                    # Copy the values first; closing a stream may mutate self.streams
                    # (NOTE(review): presumably - confirm).
                    for stream in list(self.streams.values()):
                        await stream.aclose()
def _upgrade_stream_to_ssl(raw_stream: trio.abc.Stream, hostname: str) -> trio.abc.Stream:
    """Wrap *raw_stream* in a client-side ``trio.SSLStream`` verifying *hostname*.

    Trust anchors are certifi's bundled root certificates plus, when the
    ``SSL_CAFILE`` environment variable is set, the certificates in that file.
    System certificates are intentionally not loaded: passing ``cadata`` stops
    create_default_context() from loading the system store.
    """
    # Ideally the context would be built once and kept in the config, but it
    # is unclear whether SSL configuration should be stored per device.
    #
    # System certificates are less reliable and are tried first, so they could
    # cause a failure even when a bundled certificate would validate.
    ctx = ssl.create_default_context(purpose=ssl.Purpose.SERVER_AUTH, cadata=certifi.contents())

    extra_cafile = os.environ.get("SSL_CAFILE")
    if extra_cafile:
        ctx.load_verify_locations(extra_cafile)

    return trio.SSLStream(raw_stream, ctx, server_hostname=hostname)
async def open_uds_stream(
    self,
    path: str,
    hostname: typing.Optional[str],
    ssl_context: typing.Optional[ssl.SSLContext],
    timeout: TimeoutConfig,
) -> SocketStream:
    """Open a Unix-domain-socket connection, optionally upgrading it to TLS.

    Connect and handshake together are bounded by ``timeout.connect_timeout``.

    Raises:
        ConnectTimeout: if that deadline expires.
    """
    connect_timeout = _or_inf(timeout.connect_timeout)

    with trio.move_on_after(connect_timeout) as cancel_scope:
        # Abstract Stream annotation: may be rebound to an SSLStream below.
        stream: trio.abc.Stream = await trio.open_unix_socket(path)
        if ssl_context is not None:
            stream = trio.SSLStream(stream, ssl_context, server_hostname=hostname)
            try:
                await stream.do_handshake()
            except BaseException:
                # Includes the Cancelled raised by move_on_after: close the
                # socket instead of leaking it, then let it propagate.
                await trio.aclose_forcefully(stream)
                raise

    if cancel_scope.cancelled_caught:
        raise ConnectTimeout()

    return SocketStream(stream=stream, timeout=timeout)
async def demo_client(client_raw_stream):
    """Client half of the TLS demo.

    Trusts the fake CA, presents the fake client certificate, expects the
    byte ``b"x"`` from the server, and prints the server's certificate.
    """
    ctx = ssl.create_default_context()
    # Trust the fake CA that signed the server cert, so the server validates.
    ca.configure_trust(ctx)
    # Present our fake client certificate to the server.
    client_cert.configure_cert(ctx)

    tls = trio.SSLStream(
        client_raw_stream,
        ctx,
        # Must match the hostname the server certificate was issued for
        # (the value passed to issue_cert).
        server_hostname="test-host.example.org",
    )

    assert await tls.receive_some(1) == b"x"
    print("Client successfully received data over the encrypted channel!")
    print("Server cert looks like:", tls.getpeercert())