def test_make_channel_binding_tls_server_end_point(mocker):
    """The tls-server-end-point binding hashes the peer's DER certificate.

    Per RFC 5929 section 4.1 the binding data is the hash of the server
    certificate, using the certificate's signature hash algorithm (MD5 and
    SHA-1 are upgraded to SHA-256). The mocked certificate reports
    ``sha512``, so the expected value is SHA-512 over the raw certificate
    bytes returned by ``getpeercert``.
    """
    import hashlib

    cert_der = b"cafe"
    ssl_socket = mocker.Mock()
    ssl_socket.getpeercert = mocker.Mock(return_value=cert_der)
    mock_cert = mocker.Mock()
    mock_cert.hash_algo = "sha512"
    mock_load = mocker.patch(
        "scramp.core.Certificate.load", return_value=mock_cert
    )

    result = make_channel_binding("tls-server-end-point", ssl_socket)

    assert result == (
        "tls-server-end-point",
        hashlib.sha512(cert_der).digest(),
    )
    # A real getpeercert() without binary_form=True returns a dict rather
    # than DER bytes; the mock would hide that mistake, so pin the call.
    ssl_socket.getpeercert.assert_called_once_with(binary_form=True)
    # The exact bytes from the socket must be what gets parsed.
    mock_load.assert_called_once_with(cert_der)
def __init__(
    self,
    user,
    host="localhost",
    database=None,
    port=5432,
    password=None,
    source_address=None,
    unix_sock=None,
    ssl_context=None,
    timeout=None,
    tcp_keepalive=True,
    application_name=None,
    replication=None,
):
    """Connect to a PostgreSQL server and run the startup handshake.

    Opens a TCP connection to ``(host, port)`` or, if ``unix_sock`` is given,
    a Unix-domain socket. It then optionally negotiates TLS and sends the
    protocol 3.0 startup packet. Finally it dispatches server messages until
    a ReadyForQuery or ErrorResponse message is seen.

    If anything fails after the socket has been opened, the socket (and its
    file wrapper) is closed before the exception propagates. A failed
    constructor never hands the caller an object it could ``close()``.

    :param user: Role name; ``str`` is UTF-8 encoded, bytes are passed
        through. Must not be ``None``.
    :param host: Host for a TCP connection (used only if ``unix_sock`` is
        ``None``).
    :param database: Database name; omitted from the startup packet if
        ``None``.
    :param port: TCP port.
    :param password: Password; ``str`` is UTF-8 encoded, anything else is
        stored as given.
    :param source_address: Forwarded to ``socket.create_connection``.
    :param unix_sock: Path of a Unix-domain socket; takes precedence over
        ``host``.
    :param ssl_context: ``None`` for plaintext, ``True`` for
        ``ssl.create_default_context()``, or a context object providing
        ``wrap_socket``. If the context has a falsy ``request_ssl``
        attribute, the SSLRequest exchange is skipped and no channel binding
        is computed.
    :param timeout: Socket timeout in seconds, or ``None``.
    :param tcp_keepalive: Enable ``SO_KEEPALIVE`` on the socket.
    :param application_name: Sent as the ``application_name`` startup
        parameter if not ``None``.
    :param replication: Sent as the ``replication`` startup parameter if not
        ``None``.
    :raises InterfaceError: for a ``None`` user, or a startup parameter of an
        unsupported type. Also raised for connection failure, no host/socket
        given, an unsupported Unix-socket platform, an SSL refusal by the
        server, or a missing ``ssl`` module.
    :raises: ``self.error`` if a message handler recorded an error during
        startup (presumably set by the ErrorResponse handler — confirm).
    """
    self._client_encoding = "utf8"
    # Command tags expected to carry a row count in their completion message.
    self._commands_with_count = (
        b"INSERT",
        b"DELETE",
        b"UPDATE",
        b"MOVE",
        b"FETCH",
        b"COPY",
        b"SELECT",
    )
    # Bounded: once 100 entries are held, the oldest is dropped.
    self.notifications = deque(maxlen=100)
    self.notices = deque(maxlen=100)
    self.parameter_statuses = deque(maxlen=100)

    if user is None:
        raise InterfaceError("The 'user' connection parameter cannot be None")

    init_params = {
        "user": user,
        "database": database,
        "application_name": application_name,
        "replication": replication,
    }
    # Iterate over a snapshot because entries are deleted along the way.
    for k, v in tuple(init_params.items()):
        if isinstance(v, str):
            init_params[k] = v.encode("utf8")
        elif v is None:
            del init_params[k]
        elif not isinstance(v, (bytes, bytearray)):
            raise InterfaceError(f"The parameter {k} can't be of type {type(v)}.")

    self.user = init_params["user"]

    if isinstance(password, str):
        self.password = password.encode("utf8")
    else:
        self.password = password

    self.autocommit = False
    self._xid = None
    self._statement_nums = set()
    self._caches = {}

    # Protocol state and dispatch tables. These do no I/O, so they are set up
    # before the socket is opened and stay outside the cleanup region below.
    self._backend_key_data = None
    self.pg_types = defaultdict(lambda: string_in, PG_TYPES)
    self.py_types = dict(PY_TYPES)
    self.message_types = {
        NOTICE_RESPONSE: self.handle_NOTICE_RESPONSE,
        AUTHENTICATION_REQUEST: self.handle_AUTHENTICATION_REQUEST,
        PARAMETER_STATUS: self.handle_PARAMETER_STATUS,
        BACKEND_KEY_DATA: self.handle_BACKEND_KEY_DATA,
        READY_FOR_QUERY: self.handle_READY_FOR_QUERY,
        ROW_DESCRIPTION: self.handle_ROW_DESCRIPTION,
        ERROR_RESPONSE: self.handle_ERROR_RESPONSE,
        EMPTY_QUERY_RESPONSE: self.handle_EMPTY_QUERY_RESPONSE,
        DATA_ROW: self.handle_DATA_ROW,
        COMMAND_COMPLETE: self.handle_COMMAND_COMPLETE,
        PARSE_COMPLETE: self.handle_PARSE_COMPLETE,
        BIND_COMPLETE: self.handle_BIND_COMPLETE,
        CLOSE_COMPLETE: self.handle_CLOSE_COMPLETE,
        PORTAL_SUSPENDED: self.handle_PORTAL_SUSPENDED,
        NO_DATA: self.handle_NO_DATA,
        PARAMETER_DESCRIPTION: self.handle_PARAMETER_DESCRIPTION,
        NOTIFICATION_RESPONSE: self.handle_NOTIFICATION_RESPONSE,
        COPY_DONE: self.handle_COPY_DONE,
        COPY_DATA: self.handle_COPY_DATA,
        COPY_IN_RESPONSE: self.handle_COPY_IN_RESPONSE,
        COPY_OUT_RESPONSE: self.handle_COPY_OUT_RESPONSE,
    }

    # Initialise before connecting. The error handlers below read these, and
    # previously self._usock did not exist if socket.socket() itself raised.
    self._usock = None
    self._sock = None

    if unix_sock is None and host is not None:
        try:
            self._usock = socket.create_connection(
                (host, port), timeout, source_address
            )
        except socket.error as e:
            raise InterfaceError(
                f"Can't create a connection to host {host} and port {port} "
                f"(timeout is {timeout} and source_address is {source_address})."
            ) from e

    elif unix_sock is not None:
        try:
            if not hasattr(socket, "AF_UNIX"):
                raise InterfaceError(
                    "attempt to connect to unix socket on unsupported platform"
                )
            self._usock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
            self._usock.settimeout(timeout)
            self._usock.connect(unix_sock)
        except socket.error as e:
            if self._usock is not None:
                self._usock.close()
            raise InterfaceError("communication error") from e

    else:
        raise InterfaceError("one of host or unix_sock must be provided")

    self.channel_binding = None
    try:
        if tcp_keepalive:
            self._usock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)

        if ssl_context is not None:
            try:
                import ssl

                if ssl_context is True:
                    ssl_context = ssl.create_default_context()

                # A context may set request_ssl = False to skip the
                # SSLRequest exchange and start TLS immediately.
                request_ssl = getattr(ssl_context, "request_ssl", True)

                if request_ssl:
                    # Int32(8) - Message length, including self.
                    # Int32(80877103) - The SSL request code.
                    self._usock.sendall(ii_pack(8, 80877103))
                    resp = self._usock.recv(1)
                    if resp != b"S":
                        raise InterfaceError("Server refuses SSL")

                self._usock = ssl_context.wrap_socket(
                    self._usock, server_hostname=host
                )

                if request_ssl:
                    self.channel_binding = scramp.make_channel_binding(
                        "tls-server-end-point", self._usock
                    )

            except ImportError as e:
                raise InterfaceError(
                    "SSL required but ssl module not available in this python "
                    "installation."
                ) from e

        self._sock = self._usock.makefile(mode="rwb")

        # Wrap buffered socket I/O so network failures surface as
        # InterfaceError with the OSError chained as the cause.
        def sock_flush():
            try:
                self._sock.flush()
            except OSError as e:
                raise InterfaceError("network error on flush") from e

        self._flush = sock_flush

        def sock_read(b):
            try:
                return self._sock.read(b)
            except OSError as e:
                raise InterfaceError("network error on read") from e

        self._read = sock_read

        def sock_write(d):
            try:
                self._sock.write(d)
            except OSError as e:
                raise InterfaceError("network error on write") from e

        self._write = sock_write

        # Int32 - Message length, including self.
        # Int32(196608) - Protocol version number.  Version 3.0.
        # Any number of key/value pairs, terminated by a zero byte:
        #   String - A parameter name (user, database, or options)
        #   String - Parameter value
        protocol = 196608
        val = bytearray(i_pack(protocol))

        for k, v in init_params.items():
            val.extend(k.encode("ascii") + NULL_BYTE + v + NULL_BYTE)
        val.append(0)
        self._write(i_pack(len(val) + 4))
        self._write(val)
        self._flush()

        # Each message: 1-byte type code + Int32 length (which counts itself,
        # hence the "- 4"), followed by the payload.
        code = self.error = None
        while code not in (READY_FOR_QUERY, ERROR_RESPONSE):
            code, data_len = ci_unpack(self._read(5))
            self.message_types[code](self._read(data_len - 4), None)
        if self.error is not None:
            raise self.error
    except BaseException:
        # Construction failed, so nobody will ever hold this object to call
        # close() on it: release the file wrapper and socket here instead of
        # leaking the descriptor. Cleanup errors are suppressed on purpose so
        # that the original exception is the one that propagates.
        for resource in (self._sock, self._usock):
            if resource is not None:
                try:
                    resource.close()
                except OSError:
                    pass
        raise

    self.in_transaction = False