class Bolt4x3(Bolt4x2):
    """ Protocol handler for Bolt 4.3.

    This is supported by Neo4j version 4.3.
    """

    PROTOCOL_VERSION = Version(4, 3)

    def route(self, database=None, bookmarks=None):
        """ Send a ROUTE message and wait for the server's reply.

        :param database: name of the database to route for (sent as-is;
            ``None`` is sent when not given)
        :param bookmarks: optional iterable of bookmarks; converted to a list
        :return: a one-element list holding the ``"rt"`` entry of the SUCCESS
            metadata (``None`` if the server did not send one)
        :raise BoltRoutingError: if the server reports that routing is
            unsupported or broken
        """
        def fail(md):
            from neo4j._exceptions import BoltRoutingError
            code = md.get("code")
            if code == "Neo.ClientError.Database.DatabaseNotFound":
                return  # surface this error to the user
            elif code == "Neo.ClientError.Procedure.ProcedureNotFound":
                raise BoltRoutingError("Server does not support routing", self.unresolved_address)
            else:
                raise BoltRoutingError("Routing support broken on server", self.unresolved_address)

        routing_context = self.routing_context or {}
        log.debug("[#%04X] C: ROUTE %r %r", self.local_port, routing_context, database)
        metadata = {}
        if bookmarks is None:
            bookmarks = []
        else:
            bookmarks = list(bookmarks)
        self._append(b"\x66", (routing_context, bookmarks, database),
                     response=Response(self, on_success=metadata.update, on_failure=fail))
        self.send_all()
        self.fetch_all()
        return [metadata.get("rt")]

    def hello(self):
        """ Send HELLO and apply any configuration hints the server returns.

        A positive integer ``connection.recv_timeout_seconds`` hint is applied
        as the socket timeout; any other value is logged and ignored.
        """
        def on_success(metadata):
            # The server may omit "hints" or send it as null; treat both as
            # "no hints" rather than failing on dict.update(None).
            self.configuration_hints.update(metadata.pop("hints", None) or {})
            self.server_info.update(metadata)
            if "connection.recv_timeout_seconds" in self.configuration_hints:
                recv_timeout = self.configuration_hints[
                    "connection.recv_timeout_seconds"
                ]
                # bool is a subclass of int; a boolean is not a valid timeout.
                if (isinstance(recv_timeout, int)
                        and not isinstance(recv_timeout, bool)
                        and recv_timeout > 0):
                    self.socket.settimeout(recv_timeout)
                else:
                    log.info("[#%04X] Server supplied an invalid value for "
                             "connection.recv_timeout_seconds (%r). Make sure "
                             "the server and network is set up correctly.",
                             self.local_port, recv_timeout)

        headers = self.get_base_headers()
        headers.update(self.auth_dict)
        logged_headers = dict(headers)
        # Never log the credentials themselves.
        if "credentials" in logged_headers:
            logged_headers["credentials"] = "*******"
        log.debug("[#%04X] C: HELLO %r", self.local_port, logged_headers)
        self._append(b"\x01", (headers,),
                     response=InitResponse(self, on_success=on_success))
        self.send_all()
        self.fetch_all()
        check_supported_server_product(self.server_info.agent)
class Bolt4x3(Bolt4x2):
    """Bolt 4.3 protocol handler (Neo4j 4.3).

    Routing tables are obtained with the dedicated ROUTE message rather than
    by calling a procedure.
    """

    PROTOCOL_VERSION = Version(4, 3)

    def route(self, database=None, bookmarks=None):
        """Ask the server for a routing table via a ROUTE message.

        :param database: database to route for, sent unchanged
        :param bookmarks: optional iterable of bookmarks (sent as a list)
        :return: single-item list containing the ``"rt"`` value from the
            SUCCESS metadata, or ``[None]`` if it is absent
        """
        def fail(md):
            from neo4j._exceptions import BoltRoutingError
            code = md.get("code")
            if code == "Neo.ClientError.Database.DatabaseNotFound":
                # Let the server's own error reach the caller untouched.
                return
            reason = ("Server does not support routing"
                      if code == "Neo.ClientError.Procedure.ProcedureNotFound"
                      else "Routing support broken on server")
            raise BoltRoutingError(reason, self.unresolved_address)

        context = self.routing_context or {}
        log.debug("[#%04X] C: ROUTE %r %r", self.local_port, context, database)
        bookmark_list = [] if bookmarks is None else list(bookmarks)
        routing_metadata = {}
        handler = Response(self, on_success=routing_metadata.update, on_failure=fail)
        self._append(b"\x66", (context, bookmark_list, database), response=handler)
        self.send_all()
        self.fetch_all()
        return [routing_metadata.get("rt")]
class Bolt4x2(Bolt4x1):
    """Bolt 4.2 protocol handler, used with Neo4j 4.2.

    Only the advertised protocol version differs from the parent class;
    every message is handled exactly as in Bolt4x1.
    """

    PROTOCOL_VERSION = Version(4, 2)
def version_list(cls, versions, limit=4):
    """ Return a list of supported protocol versions in order of
    preference. The number of protocol versions (or ranges) returned is
    limited to ``limit`` (four by default).

    When the most preferred version is 4.3 or later, the leading run of
    versions sharing its major number is collapsed into a single
    ``Version(major, [minor, ...])`` entry. It is followed by the remaining
    versions of that run, truncated so that the list has at most
    ``limit - 1`` entries. Finally ``versions[end]`` is appended: the first
    version of the next major number, or the last version when every
    version shares the same major number.

    :param versions: versions sorted from most to least preferred
    :param limit: maximum number of entries to return
    :return: the (possibly range-compressed) version list; an empty input
        yields an empty slice of the input
    """
    if not versions:
        # Checked first so that versions[0] is never read on an empty input.
        return versions[:limit]
    if not versions[0] >= Version(4, 3):
        # Range negotiation is only used from Bolt 4.3 onwards.
        return versions[:limit]

    first_major = versions[0][0]
    minors = []
    end = 0
    for end, version in enumerate(versions):
        if version[0] != first_major:
            break
        minors.append(version[1])
    new_versions = ([Version(first_major, minors)] + versions[1:end])[:(limit - 1)]
    # versions is non-empty and enumerate only yields valid indices, so
    # versions[end] always exists here.
    new_versions.append(versions[end])
    return new_versions
async def _handshake(cls, reader, writer, protocol_version):
    """ Carry out a Bolt handshake, optionally requesting a specific
    protocol version.

    Up to four candidate versions (highest first) are offered after the
    magic preamble. The 4-byte reply is decoded into the agreed version.

    :param reader: stream reader for the connection (must support
        ``readexactly``)
    :param writer: stream writer for the connection (must support
        ``write``/``drain``/``close`` and expose ``transport``)
    :param protocol_version: requested version, passed to
        ``cls.protocol_handlers`` to select candidate handlers
    :return: the handler class registered for the agreed version
    :raise ValueError: if no protocol handlers are available for the request
    :raise BoltConnectionLost: if an I/O error occurs on the underlying socket connection
    :raise BoltHandshakeError: if handshake completes without a successful negotiation
    """
    local_address = Address(writer.transport.get_extra_info("sockname"))
    remote_address = Address(writer.transport.get_extra_info("peername"))

    handlers = cls.protocol_handlers(protocol_version)
    if not handlers:
        # Format the message explicitly; wrap the argument in a tuple so a
        # tuple-like version is not unpacked by the % operator.
        raise ValueError("No protocol handlers available (requested Bolt %r)"
                         % (protocol_version,))
    offered_versions = sorted(handlers.keys(), reverse=True)[:4]

    request_data = MAGIC + b"".join(
        v.to_bytes() for v in offered_versions).ljust(16, b"\x00")
    log.debug("[#%04X] C: <HANDSHAKE> %r", local_address.port_number, request_data)
    writer.write(request_data)
    await writer.drain()
    response_data = await reader.readexactly(4)
    log.debug("[#%04X] S: <HANDSHAKE> %r", local_address.port_number, response_data)
    try:
        agreed_version = Version.from_bytes(response_data)
    except ValueError as err:
        writer.close()
        raise BoltHandshakeError(
            "Unexpected handshake response %r" % response_data,
            remote_address, request_data, response_data) from err
    try:
        subclass = handlers[agreed_version]
    except KeyError:
        log.debug("Unsupported Bolt protocol version %s", agreed_version)
        # Mirror the failure path above: the connection cannot be used.
        writer.close()
        raise BoltHandshakeError("Unsupported Bolt protocol version",
                                 remote_address, request_data, response_data)
    else:
        return subclass
class Bolt4x1(Bolt4x0):
    """Bolt 4.1 protocol handler (Neo4j 4.1 and 4.2)."""

    PROTOCOL_VERSION = Version(4, 1)

    def get_base_headers(self):
        """Return the base HELLO headers for Bolt 4.1.

        Besides the user agent, Bolt 4.1 sends the routing context (taken
        from the connection URI) so that server-side routing can apply the
        same behaviour through its own driver.
        """
        headers = dict(user_agent=self.user_agent)
        headers["routing"] = self.routing_context
        return headers
def supports_multi_db(self):
    """ Check if the server or cluster supports multi-databases.

    A connection is borrowed from the pool to inspect the negotiated
    protocol version and the server version. It is always returned to the
    pool, even if the inspection raises.

    :return: Returns true if the server or cluster the driver connects to
        supports multi-databases, otherwise false.
    :rtype: bool
    """
    from neo4j.io._bolt4x0 import Bolt4x0

    multi_database = False
    cx = self._pool.acquire(
        access_mode=READ_ACCESS,
        timeout=self._pool.workspace_config.connection_acquisition_timeout,
        database=self._pool.workspace_config.database)
    try:
        # TODO: This logic should be inside the Bolt subclasses, because it
        # can change depending on Bolt Protocol Version.
        if (cx.PROTOCOL_VERSION >= Bolt4x0.PROTOCOL_VERSION
                and cx.server_info.version_info() >= Version(4, 0, 0)):
            multi_database = True
    finally:
        # Release even on error so the pooled connection is not leaked.
        self._pool.release(cx)
    return multi_database
result = session.run("X") list(result.keys()) def test_should_not_allow_empty_statements(session): with raises(ValueError): _ = session.run("") def test_statement_object(session): value = session.run(Statement("RETURN $x"), x=1).single().value() assert value == 1 @pytest.mark.parametrize("test_input, neo4j_version", [ ("CALL dbms.getTXMetaData", Version(3, 0)), ("CALL tx.getMetaData", Version(4, 0)), ]) def test_autocommit_transactions_should_support_metadata( session, test_input, neo4j_version): # python -m pytest tests/integration/test_autocommit.py -s -r fEsxX -k test_autocommit_transactions_should_support_metadata metadata_in = {"foo": "bar"} result = session.run("RETURN 1") value = result.single().value() summary = result.summary() server_agent = summary.server.agent try: statement = Statement(test_input, metadata=metadata_in) result = session.run(statement)
class Bolt4x0(Bolt):
    """ Protocol handler for Bolt 4.0.

    This is supported by Neo4j versions 4.0, 4.1 and 4.2.
    """

    PROTOCOL_VERSION = Version(4, 0)

    supports_multiple_results = True

    supports_multiple_databases = True

    @property
    def encrypted(self):
        """ True if the underlying socket is an SSL socket. """
        return isinstance(self.socket, SSLSocket)

    @property
    def der_encoded_server_certificate(self):
        """ The peer certificate in binary (DER) form, as returned by the socket. """
        return self.socket.getpeercert(binary_form=True)

    @property
    def local_port(self):
        """ Local port of the socket, or 0 if it cannot be determined. """
        try:
            return self.socket.getsockname()[1]
        except OSError:
            return 0

    def get_base_headers(self):
        """ Headers sent with HELLO before auth details are merged in. """
        return {
            "user_agent": self.user_agent,
        }

    def hello(self):
        """ Send HELLO (with auth) and record the server's reply in server_info. """
        headers = self.get_base_headers()
        headers.update(self.auth_dict)
        logged_headers = dict(headers)
        # Never log the credentials themselves.
        if "credentials" in logged_headers:
            logged_headers["credentials"] = "*******"
        log.debug("[#%04X] C: HELLO %r", self.local_port, logged_headers)
        self._append(b"\x01", (headers,),
                     response=InitResponse(self, on_success=self.server_info.update))
        self.send_all()
        self.fetch_all()
        check_supported_server_product(self.server_info.agent)

    def route(self, database=None, bookmarks=None):
        """ Fetch routing information by calling the routing procedure on the
        system database.

        :param database: database to route for; ``None`` uses the
            single-argument form of the procedure (default database)
        :param bookmarks: bookmarks passed through to RUN
        :return: list of dicts, one per returned record, keyed by the result
            field names
        """
        metadata = {}
        records = []

        def fail(md):
            from neo4j._exceptions import BoltRoutingError
            code = md.get("code")
            if code == "Neo.ClientError.Database.DatabaseNotFound":
                return  # surface this error to the user
            elif code == "Neo.ClientError.Procedure.ProcedureNotFound":
                raise BoltRoutingError("Server does not support routing", self.unresolved_address)
            else:
                raise BoltRoutingError("Routing support broken on server", self.unresolved_address)

        if database is None:  # default database
            self.run(
                "CALL dbms.routing.getRoutingTable($context)",
                {"context": self.routing_context},
                mode="r",
                bookmarks=bookmarks,
                db=SYSTEM_DATABASE,
                on_success=metadata.update,
                on_failure=fail
            )
        else:
            self.run(
                "CALL dbms.routing.getRoutingTable($context, $database)",
                {"context": self.routing_context, "database": database},
                mode="r",
                bookmarks=bookmarks,
                db=SYSTEM_DATABASE,
                on_success=metadata.update,
                on_failure=fail
            )
        self.pull(on_success=metadata.update, on_records=records.extend)
        self.send_all()
        self.fetch_all()
        routing_info = [dict(zip(metadata.get("fields", ()), values)) for values in records]
        return routing_info

    @staticmethod
    def _build_extra(mode, bookmarks, metadata, timeout, db):
        """ Build the ``extra`` dictionary shared by RUN and BEGIN.

        A key is only added when its argument is truthy (or, for ``mode``,
        when read access was requested).

        :raise TypeError: if bookmarks, metadata or timeout have the wrong type
        """
        extra = {}
        if mode in (READ_ACCESS, "r"):
            extra["mode"] = "r"  # It will default to mode "w" if nothing is specified
        if db:
            extra["db"] = db
        if bookmarks:
            try:
                extra["bookmarks"] = list(bookmarks)
            except TypeError:
                raise TypeError("Bookmarks must be provided within an iterable")
        if metadata:
            try:
                extra["tx_metadata"] = dict(metadata)
            except TypeError:
                raise TypeError("Metadata must be coercible to a dict")
        if timeout:
            try:
                # Seconds on the API, milliseconds on the wire.
                extra["tx_timeout"] = int(1000 * timeout)
            except TypeError:
                raise TypeError("Timeout must be specified as a number of seconds")
        return extra

    def run(self, query, parameters=None, mode=None, bookmarks=None, metadata=None, timeout=None, db=None, **handlers):
        """ Queue a RUN message; a literal "COMMIT" query gets a CommitResponse. """
        if not parameters:
            parameters = {}
        extra = self._build_extra(mode, bookmarks, metadata, timeout, db)
        fields = (query, parameters, extra)
        log.debug("[#%04X] C: RUN %s", self.local_port, " ".join(map(repr, fields)))
        if query.upper() == u"COMMIT":
            self._append(b"\x10", fields, CommitResponse(self, **handlers))
        else:
            self._append(b"\x10", fields, Response(self, **handlers))
        self._is_reset = False

    def discard(self, n=-1, qid=-1, **handlers):
        """ Queue a DISCARD message; ``qid`` is only sent when not -1. """
        extra = {"n": n}
        if qid != -1:
            extra["qid"] = qid
        log.debug("[#%04X] C: DISCARD %r", self.local_port, extra)
        self._append(b"\x2F", (extra,), Response(self, **handlers))

    def pull(self, n=-1, qid=-1, **handlers):
        """ Queue a PULL message; ``qid`` is only sent when not -1. """
        extra = {"n": n}
        if qid != -1:
            extra["qid"] = qid
        log.debug("[#%04X] C: PULL %r", self.local_port, extra)
        self._append(b"\x3F", (extra,), Response(self, **handlers))
        self._is_reset = False

    def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, db=None, **handlers):
        """ Queue a BEGIN message for an explicit transaction. """
        extra = self._build_extra(mode, bookmarks, metadata, timeout, db)
        log.debug("[#%04X] C: BEGIN %r", self.local_port, extra)
        self._append(b"\x11", (extra,), Response(self, **handlers))
        self._is_reset = False

    def commit(self, **handlers):
        """ Queue a COMMIT message. """
        log.debug("[#%04X] C: COMMIT", self.local_port)
        self._append(b"\x12", (), CommitResponse(self, **handlers))

    def rollback(self, **handlers):
        """ Queue a ROLLBACK message. """
        log.debug("[#%04X] C: ROLLBACK", self.local_port)
        self._append(b"\x13", (), Response(self, **handlers))

    def reset(self):
        """ Add a RESET message to the outgoing queue, send
        it and consume all remaining messages.
        """

        def fail(metadata):
            raise BoltProtocolError("RESET failed %r" % metadata, self.unresolved_address)

        log.debug("[#%04X] C: RESET", self.local_port)
        self._append(b"\x0F", response=Response(self, on_failure=fail))
        self.send_all()
        self.fetch_all()
        self._is_reset = True

    def fetch_message(self):
        """ Receive at most one message from the server, if available.

        :return: 2-tuple of number of detail messages and number of summary
                 messages fetched
        """
        if self._closed:
            raise ServiceUnavailable("Failed to read from closed connection {!r} ({!r})".format(
                self.unresolved_address, self.server_info.address))

        if self._defunct:
            raise ServiceUnavailable("Failed to read from defunct connection {!r} ({!r})".format(
                self.unresolved_address, self.server_info.address))

        if not self.responses:
            return 0, 0

        # Receive exactly one message
        details, summary_signature, summary_metadata = next(self.inbox)

        if details:
            log.debug("[#%04X] S: RECORD * %d", self.local_port, len(details))  # Do not log any data
            self.responses[0].on_records(details)

        if summary_signature is None:
            return len(details), 0

        response = self.responses.popleft()
        response.complete = True
        if summary_signature == b"\x70":
            log.debug("[#%04X] S: SUCCESS %r", self.local_port, summary_metadata)
            response.on_success(summary_metadata or {})
        elif summary_signature == b"\x7E":
            log.debug("[#%04X] S: IGNORED", self.local_port)
            response.on_ignored(summary_metadata or {})
        elif summary_signature == b"\x7F":
            log.debug("[#%04X] S: FAILURE %r", self.local_port, summary_metadata)
            try:
                response.on_failure(summary_metadata or {})
            except (ServiceUnavailable, DatabaseUnavailable):
                if self.pool:
                    self.pool.deactivate(address=self.unresolved_address)
                raise
            except (NotALeader, ForbiddenOnReadOnlyDatabase):
                if self.pool:
                    self.pool.on_write_failure(address=self.unresolved_address)
                raise
            except Neo4jError as e:
                if self.pool and e.invalidates_all_connections():
                    self.pool.mark_all_stale()
                raise
        else:
            raise BoltProtocolError("Unexpected response message with signature "
                                    "%02X" % ord(summary_signature), self.unresolved_address)

        return len(details), 1

    def close(self):
        """ Close the connection.

        A GOODBYE is sent on a best-effort basis unless the connection is
        defunct; send errors are deliberately ignored.
        """
        if not self._closed:
            if not self._defunct:
                log.debug("[#%04X] C: GOODBYE", self.local_port)
                self._append(b"\x02", ())
                try:
                    self._send_all()
                except (OSError, BoltError, DriverError):
                    pass
            log.debug("[#%04X] C: <CLOSE>", self.local_port)
            try:
                self.socket.close()
            except OSError:
                pass
            finally:
                self._closed = True

    def closed(self):
        return self._closed

    def defunct(self):
        return self._defunct
class Bolt3(Bolt):
    """ Protocol handler for Bolt 3 (synchronous, socket based). """

    PROTOCOL_VERSION = Version(3, 0)

    # NOTE(review): presumably toggled by the pool while this connection is
    # checked out — confirm against the pool implementation.
    in_use = False

    # Set to True by close().
    _closed = False

    # Set to True by _set_defunct() when the connection fails unexpectedly.
    _defunct = False

    #: The pool of which this connection is a member
    pool = None

    def __init__(self, unresolved_address, sock, max_connection_lifetime, *, auth=None, user_agent=None):
        self.unresolved_address = unresolved_address
        self.socket = sock
        self.server_info = ServerInfo(Address(sock.getpeername()), Bolt3.PROTOCOL_VERSION)
        self.outbox = Outbox()
        self.inbox = Inbox(self.socket, on_error=self._set_defunct)
        self.packer = Packer(self.outbox)
        self.unpacker = Unpacker(self.inbox)
        self.responses = deque()
        self._max_connection_lifetime = max_connection_lifetime
        self._creation_timestamp = perf_counter()
        self.supports_multiple_results = False
        self.supports_multiple_databases = False
        self._is_reset = True

        # Determine the user agent
        if user_agent:
            self.user_agent = user_agent
        else:
            self.user_agent = get_user_agent()

        # Determine auth details
        if not auth:
            self.auth_dict = {}
        elif isinstance(auth, tuple) and 2 <= len(auth) <= 3:
            from neo4j import Auth
            self.auth_dict = vars(Auth("basic", *auth))
        else:
            try:
                self.auth_dict = vars(auth)
            except (KeyError, TypeError):
                # Wrap in a tuple: a bare tuple auth would otherwise be
                # unpacked by % and raise a formatting TypeError instead.
                raise AuthError("Cannot determine auth details from %r" % (auth,))

        # Check for missing password
        try:
            credentials = self.auth_dict["credentials"]
        except KeyError:
            pass
        else:
            if credentials is None:
                raise AuthError("Password cannot be None")

    @property
    def encrypted(self):
        """ True if the underlying socket is an SSL socket. """
        return isinstance(self.socket, SSLSocket)

    @property
    def der_encoded_server_certificate(self):
        """ The peer certificate in binary (DER) form, as returned by the socket. """
        return self.socket.getpeercert(binary_form=True)

    @property
    def local_port(self):
        """ Local port of the socket, or 0 if it cannot be determined. """
        try:
            return self.socket.getsockname()[1]
        except IOError:
            return 0

    def hello(self):
        """ Send HELLO (with auth) and record the reply in server_info.metadata. """
        headers = {"user_agent": self.user_agent}
        headers.update(self.auth_dict)
        logged_headers = dict(headers)
        # Never log the credentials themselves.
        if "credentials" in logged_headers:
            logged_headers["credentials"] = "*******"
        log.debug("[#%04X] C: HELLO %r", self.local_port, logged_headers)
        self._append(b"\x01", (headers,),
                     response=InitResponse(self, on_success=self.server_info.metadata.update))
        self.send_all()
        self.fetch_all()

    def run(self, query, parameters=None, mode=None, bookmarks=None, metadata=None, timeout=None, db=None, **handlers):
        """ Queue a RUN message.

        :raise ConfigurationError: if ``db`` is given (not supported by Bolt 3)
        :raise TypeError: if bookmarks, metadata or timeout have the wrong type
        """
        if db is not None:
            raise ConfigurationError("Database name parameter for selecting database is not supported in Bolt Protocol {!r}. Database name {!r}.".format(Bolt3.PROTOCOL_VERSION, db))
        if not parameters:
            parameters = {}
        extra = {}
        if mode in (READ_ACCESS, "r"):
            extra["mode"] = "r"  # It will default to mode "w" if nothing is specified
        if bookmarks:
            try:
                extra["bookmarks"] = list(bookmarks)
            except TypeError:
                raise TypeError("Bookmarks must be provided within an iterable")
        if metadata:
            try:
                extra["tx_metadata"] = dict(metadata)
            except TypeError:
                raise TypeError("Metadata must be coercible to a dict")
        if timeout:
            try:
                # Seconds on the API, milliseconds on the wire.
                extra["tx_timeout"] = int(1000 * timeout)
            except TypeError:
                raise TypeError("Timeout must be specified as a number of seconds")
        fields = (query, parameters, extra)
        log.debug("[#%04X] C: RUN %s", self.local_port, " ".join(map(repr, fields)))
        if query.upper() == u"COMMIT":
            self._append(b"\x10", fields, CommitResponse(self, **handlers))
        else:
            self._append(b"\x10", fields, Response(self, **handlers))
        self._is_reset = False

    def discard(self, n=-1, qid=-1, **handlers):
        # Just ignore n and qid, it is not supported in the Bolt 3 Protocol.
        log.debug("[#%04X] C: DISCARD_ALL", self.local_port)
        self._append(b"\x2F", (), Response(self, **handlers))

    def pull(self, n=-1, qid=-1, **handlers):
        # Just ignore n and qid, it is not supported in the Bolt 3 Protocol.
        log.debug("[#%04X] C: PULL_ALL", self.local_port)
        self._append(b"\x3F", (), Response(self, **handlers))
        self._is_reset = False

    def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, db=None, **handlers):
        """ Queue a BEGIN message.

        :raise ConfigurationError: if ``db`` is given (not supported by Bolt 3)
        :raise TypeError: if bookmarks, metadata or timeout have the wrong type
        """
        if db is not None:
            raise ConfigurationError("Database name parameter for selecting database is not supported in Bolt Protocol {!r}. Database name {!r}.".format(Bolt3.PROTOCOL_VERSION, db))
        extra = {}
        if mode in (READ_ACCESS, "r"):
            extra["mode"] = "r"  # It will default to mode "w" if nothing is specified
        if bookmarks:
            try:
                extra["bookmarks"] = list(bookmarks)
            except TypeError:
                raise TypeError("Bookmarks must be provided within an iterable")
        if metadata:
            try:
                extra["tx_metadata"] = dict(metadata)
            except TypeError:
                raise TypeError("Metadata must be coercible to a dict")
        if timeout:
            try:
                extra["tx_timeout"] = int(1000 * timeout)
            except TypeError:
                raise TypeError("Timeout must be specified as a number of seconds")
        log.debug("[#%04X] C: BEGIN %r", self.local_port, extra)
        self._append(b"\x11", (extra,), Response(self, **handlers))
        self._is_reset = False

    def commit(self, **handlers):
        log.debug("[#%04X] C: COMMIT", self.local_port)
        self._append(b"\x12", (), CommitResponse(self, **handlers))

    def rollback(self, **handlers):
        log.debug("[#%04X] C: ROLLBACK", self.local_port)
        self._append(b"\x13", (), Response(self, **handlers))

    def _append(self, signature, fields=(), response=None):
        """ Add a message to the outgoing queue.

        :arg signature: the signature of the message
        :arg fields: the fields of the message as a tuple
        :arg response: a response object to handle callbacks
        """
        self.packer.pack_struct(signature, fields)
        self.outbox.chunk()
        # NOTE(review): the second chunk() appears deliberate (presumably it
        # emits the end-of-message marker) — confirm against Outbox before
        # removing it.
        self.outbox.chunk()
        self.responses.append(response)

    def reset(self):
        """ Add a RESET message to the outgoing queue, send
        it and consume all remaining messages.
        """

        def fail(metadata):
            raise BoltProtocolError("RESET failed %r" % metadata, address=self.unresolved_address)

        log.debug("[#%04X] C: RESET", self.local_port)
        self._append(b"\x0F", response=Response(self, on_failure=fail))
        self.send_all()
        self.fetch_all()
        self._is_reset = True

    def _send_all(self):
        """ Write the outbox to the socket and clear it (no state checks). """
        data = self.outbox.view()
        if data:
            self.socket.sendall(data)
            self.outbox.clear()

    def send_all(self):
        """ Send all queued messages to the server.
        """
        if self.closed():
            raise ServiceUnavailable("Failed to write to closed connection {!r} ({!r})".format(
                self.unresolved_address, self.server_info.address))

        if self.defunct():
            raise ServiceUnavailable("Failed to write to defunct connection {!r} ({!r})".format(
                self.unresolved_address, self.server_info.address))

        try:
            self._send_all()
        except (IOError, OSError) as error:
            log.error("Failed to write data to connection "
                      "{!r} ({!r}); ({!r})".
                      format(self.unresolved_address,
                             self.server_info.address,
                             "; ".join(map(repr, error.args))))
            if self.pool:
                self.pool.deactivate(address=self.unresolved_address)
            raise

    def fetch_message(self):
        """ Receive at least one message from the server, if available.

        :return: 2-tuple of number of detail messages and number of summary
                 messages fetched
        """
        if self._closed:
            raise ServiceUnavailable("Failed to read from closed connection {!r} ({!r})".format(
                self.unresolved_address, self.server_info.address))

        if self._defunct:
            raise ServiceUnavailable("Failed to read from defunct connection {!r} ({!r})".format(
                self.unresolved_address, self.server_info.address))

        if not self.responses:
            return 0, 0

        # Receive exactly one message
        try:
            details, summary_signature, summary_metadata = next(self.inbox)
        except (IOError, OSError) as error:
            log.error("Failed to read data from connection "
                      "{!r} ({!r}); ({!r})".
                      format(self.unresolved_address,
                             self.server_info.address,
                             "; ".join(map(repr, error.args))))
            if self.pool:
                self.pool.deactivate(address=self.unresolved_address)
            raise

        if details:
            log.debug("[#%04X] S: RECORD * %d", self.local_port, len(details))  # Do not log any data
            self.responses[0].on_records(details)

        if summary_signature is None:
            return len(details), 0

        response = self.responses.popleft()
        response.complete = True
        if summary_signature == b"\x70":
            log.debug("[#%04X] S: SUCCESS %r", self.local_port, summary_metadata)
            response.on_success(summary_metadata or {})
        elif summary_signature == b"\x7E":
            log.debug("[#%04X] S: IGNORED", self.local_port)
            response.on_ignored(summary_metadata or {})
        elif summary_signature == b"\x7F":
            log.debug("[#%04X] S: FAILURE %r", self.local_port, summary_metadata)
            try:
                response.on_failure(summary_metadata or {})
            except (ServiceUnavailable, DatabaseUnavailable):
                if self.pool:
                    self.pool.deactivate(address=self.unresolved_address)
                raise
            except (NotALeader, ForbiddenOnReadOnlyDatabase):
                if self.pool:
                    self.pool.on_write_failure(address=self.unresolved_address)
                raise
        else:
            # The signature is a bytes object; %X needs an int, so convert it.
            raise BoltProtocolError("Unexpected response message with signature %02X" % ord(summary_signature), address=self.unresolved_address)

        return len(details), 1

    def _set_defunct(self, error=None):
        """ Mark the connection defunct, close it and raise a suitable error.

        :raise BoltIncompleteCommitError: if a COMMIT response is outstanding
        :raise ServiceUnavailable: for a direct (BoltPool) driver
        :raise SessionExpired: otherwise
        """
        direct_driver = isinstance(self.pool, BoltPool)

        message = ("Failed to read from defunct connection {!r} ({!r})".format(
            self.unresolved_address, self.server_info.address))

        log.error(message)
        # We were attempting to receive data but the connection
        # has unexpectedly terminated. So, we need to close the
        # connection from the client side, and remove the address
        # from the connection pool.
        self._defunct = True
        self.close()
        if self.pool:
            self.pool.deactivate(address=self.unresolved_address)
        # Iterate through the outstanding responses, and if any correspond
        # to COMMIT requests then raise an error to signal that we are
        # unable to confirm that the COMMIT completed successfully.
        for response in self.responses:
            if isinstance(response, CommitResponse):
                raise BoltIncompleteCommitError(message, address=None)

        if direct_driver:
            raise ServiceUnavailable(message)
        else:
            raise SessionExpired(message)

    def timedout(self):
        """ True once the connection is older than its maximum lifetime
        (a negative lifetime disables the check). """
        return 0 <= self._max_connection_lifetime <= perf_counter() - self._creation_timestamp

    def fetch_all(self):
        """ Fetch all outstanding messages.

        :return: 2-tuple of number of detail messages and number of summary
                 messages fetched
        """
        detail_count = summary_count = 0
        while self.responses:
            response = self.responses[0]
            while not response.complete:
                detail_delta, summary_delta = self.fetch_message()
                detail_count += detail_delta
                summary_count += summary_delta
        return detail_count, summary_count

    def close(self):
        """ Close the connection.

        A GOODBYE is sent on a best-effort basis unless the connection is
        defunct; send errors are deliberately ignored.
        """
        if not self._closed:
            if not self._defunct:
                log.debug("[#%04X] C: GOODBYE", self.local_port)
                self._append(b"\x02", ())
                try:
                    self._send_all()
                except Exception:
                    # Best effort only; unlike a bare except, this does not
                    # swallow KeyboardInterrupt/SystemExit.
                    pass
            log.debug("[#%04X] C: <CLOSE>", self.local_port)
            try:
                self.socket.close()
            except IOError:
                pass
            finally:
                self._closed = True

    def closed(self):
        return self._closed

    def defunct(self):
        return self._defunct
class Bolt3(Bolt):
    """ Asynchronous Bolt 3 connection with pluggable failure handlers. """

    protocol_version = Version(3, 0)

    server_agent = None

    connection_id = None

    def __init__(self, reader, writer):
        self._courier = Courier(reader, writer, self.fail)
        self._tx = None
        self._failure_handlers = {}

    async def __ainit__(self, auth):
        """ Send HELLO and record server agent / connection id on success.

        On failure the connection is closed and a BoltFailure is passed to
        ``fail``, which either hands it to a registered handler or raises it.
        """
        args = {
            "scheme": "none",
            "user_agent": self.default_user_agent(),
        }
        if auth:
            args.update({
                "scheme": "basic",
                "principal": auth[0],  # TODO
                "credentials": auth[1],  # TODO
            })
        response = self._courier.write_hello(args)
        await self._courier.send()
        summary = await response.get_summary()
        if summary.success:
            self.server_agent = summary.metadata.get("server")
            self.connection_id = summary.metadata.get("connection_id")
            # TODO: verify genuine product
        else:
            await super().close()
            code = summary.metadata.get("code")
            message = summary.metadata.get("message")
            failure = BoltFailure(message, self.remote_address, code, response)
            self.fail(failure)

    async def close(self):
        """ Send GOODBYE (best effort, unless broken) and close the connection. """
        if self.closed:
            return
        if not self.broken:
            self._courier.write_goodbye()
            try:
                await self._courier.send()
            except BoltConnectionBroken:
                pass
        await super().close()

    @property
    def ready(self):
        """ If true, this flag indicates that there is no transaction
        in progress, and one may be started.
        """
        return not self._tx or self._tx.closed

    def _assert_open(self):
        if self.closed:
            raise BoltConnectionClosed("Connection has been closed", self.remote_address)
        if self.broken:
            raise BoltConnectionBroken("Connection is broken", self.remote_address)

    def _assert_ready(self):
        self._assert_open()
        if not self.ready:
            # TODO: add transaction identifier
            raise BoltTransactionError(
                "A transaction is already in progress on "
                "this connection", self.remote_address)

    async def reset(self, force=False):
        """ Queue RESET if forced or a transaction is active, then flush any
        pending requests and responses. """
        self._assert_open()
        if force or not self.ready:
            self._courier.write_reset()
        if self._courier.requests_to_send:
            await self._courier.send()
        if self._courier.responses_to_fetch:
            await self._courier.fetch()

    async def run(self, cypher, parameters=None, discard=False, readonly=False,
                  bookmarks=None, timeout=None, metadata=None):
        """ Run a query in a new auto-commit transaction. """
        self._assert_ready()
        self._tx = Transaction(self._courier, readonly=readonly, bookmarks=bookmarks,
                               timeout=timeout, metadata=metadata)
        return await self._tx.run(cypher, parameters, discard=discard)

    async def begin(self, readonly=False, bookmarks=None,
                    timeout=None, metadata=None):
        """ Begin an explicit transaction and return it. """
        self._assert_ready()
        self._tx = await Transaction.begin(self._courier, readonly=readonly, bookmarks=bookmarks,
                                           timeout=timeout, metadata=metadata)
        return self._tx

    async def run_tx(self, f, args=None, kwargs=None, readonly=False,
                     bookmarks=None, timeout=None, metadata=None):
        """ Run coroutine function ``f(tx, *args, **kwargs)`` inside a
        transaction, committing on success and rolling back on any exception.

        :raise TypeError: if ``f`` is not a coroutine function (checked before
            any transaction is started)
        """
        self._assert_open()
        # Validate first so a bad callable does not leave a transaction open.
        if not iscoroutinefunction(f):
            raise TypeError("Transaction function must be awaitable")
        # Forward the caller's timeout (it was previously dropped).
        tx = await self.begin(readonly=readonly, bookmarks=bookmarks,
                              timeout=timeout, metadata=metadata)
        try:
            value = await f(tx, *(args or ()), **(kwargs or {}))
        except Exception:
            await tx.rollback()
            raise
        else:
            await tx.commit()
            return value

    async def get_routing_table(self, context=None):
        """ Call the cluster routing procedure and parse the result.

        :raise BoltRoutingError: if no data is returned, routing is not
            supported, or the table is invalid
        """
        try:
            result = await self.run(
                "CALL dbms.cluster.routing.getRoutingTable($context)",
                {"context": dict(context or {})})
            record = await result.single()
            if not record:
                raise BoltRoutingError("Routing table call returned "
                                       "no data", self.remote_address)
            assert isinstance(record, Record)
            servers = record["servers"]
            ttl = record["ttl"]
            log.debug("[#%04X] S: <ROUTING> servers=%r ttl=%r",
                      self.local_address.port_number, servers, ttl)
            return RoutingTable.parse_routing_info(servers, ttl)
        except BoltFailure as error:
            if error.title == "ProcedureNotFound":
                raise BoltRoutingError("Server does not support "
                                       "routing", self.remote_address) from error
            else:
                raise
        except ValueError as error:
            raise BoltRoutingError("Invalid routing table",
                                   self.remote_address) from error

    def fail(self, failure):
        """ Dispatch ``failure`` to the handler registered for its exact type,
        or raise it if none is registered. """
        t = type(failure)
        handler = self.get_failure_handler(t)
        if callable(handler):
            # TODO: fix "requires two params, only one was given" error
            handler(failure)
        else:
            raise failure

    def get_failure_handler(self, cls):
        return self._failure_handlers.get(cls)

    def set_failure_handler(self, cls, f):
        self._failure_handlers[cls] = f

    def del_failure_handler(self, cls):
        try:
            del self._failure_handlers[cls]
        except KeyError:
            pass
class Bolt3(Bolt):
    """ Asynchronous Bolt 3 connection. """

    protocol_version = Version(3, 0)

    server_agent = None

    connection_id = None

    def __init__(self, reader, writer):
        self._courier = Courier(reader, writer)
        self._tx = None

    async def __ainit__(self, auth):
        """ Send HELLO and record server agent / connection id on success.

        :raise BoltFailure: if the server rejects HELLO (the connection is
            closed first)
        """
        args = {
            "scheme": "none",
            "user_agent": self.default_user_agent(),
        }
        if auth:
            args.update({
                "scheme": "basic",
                "principal": auth[0],  # TODO
                "credentials": auth[1],  # TODO
            })
        response = self._courier.write_hello(args)
        await self._courier.send()
        summary = await response.get_summary()
        if summary.success:
            self.server_agent = summary.metadata.get("server")
            self.connection_id = summary.metadata.get("connection_id")
            # TODO: verify genuine product
        else:
            await super().close()
            code = summary.metadata.get("code")
            message = summary.metadata.get("message")
            raise BoltFailure(message, self.remote_address, code, response)

    async def close(self):
        """ Send GOODBYE (best effort, unless broken) and close the connection. """
        if self.closed:
            return
        if not self.broken:
            self._courier.write_goodbye()
            try:
                await self._courier.send()
            except BoltConnectionBroken:
                pass
        await super().close()

    @property
    def ready(self):
        """ If true, this flag indicates that there is no transaction
        in progress, and one may be started.
        """
        return not self._tx or self._tx.closed

    def _assert_open(self):
        if self.closed:
            raise BoltConnectionClosed("Connection has been closed", self.remote_address)
        if self.broken:
            raise BoltConnectionBroken("Connection is broken", self.remote_address)

    def _assert_ready(self):
        self._assert_open()
        if not self.ready:
            # TODO: add transaction identifier
            raise BoltTransactionError(
                "A transaction is already in progress on "
                "this connection", self.remote_address)

    async def reset(self, force=False):
        """ Queue RESET if forced or a transaction is active, then flush any
        pending requests and responses. """
        self._assert_open()
        if force or not self.ready:
            self._courier.write_reset()
        if self._courier.requests_to_send:
            await self._courier.send()
        if self._courier.responses_to_fetch:
            await self._courier.fetch()

    async def run(self, cypher, parameters=None, discard=False, readonly=False,
                  bookmarks=None, timeout=None, metadata=None):
        """ Run a query in a new auto-commit transaction. """
        self._assert_ready()
        self._tx = Transaction(self._courier, readonly=readonly, bookmarks=bookmarks,
                               timeout=timeout, metadata=metadata)
        return await self._tx.run(cypher, parameters, discard=discard)

    async def begin(self, readonly=False, bookmarks=None,
                    timeout=None, metadata=None):
        """ Begin an explicit transaction and return it. """
        self._assert_ready()
        self._tx = await Transaction.begin(self._courier, readonly=readonly, bookmarks=bookmarks,
                                           timeout=timeout, metadata=metadata)
        return self._tx

    async def run_tx(self, f, args=None, kwargs=None, readonly=False,
                     bookmarks=None, timeout=None, metadata=None):
        """ Run coroutine function ``f(tx, *args, **kwargs)`` inside a
        transaction, committing on success and rolling back on any exception.

        :raise TypeError: if ``f`` is not a coroutine function (checked before
            any transaction is started)
        """
        self._assert_open()
        # Validate first so a bad callable does not leave a transaction open.
        if not iscoroutinefunction(f):
            raise TypeError("Transaction function must be awaitable")
        # Forward the caller's timeout (it was previously dropped).
        tx = await self.begin(readonly=readonly, bookmarks=bookmarks,
                              timeout=timeout, metadata=metadata)
        try:
            value = await f(tx, *(args or ()), **(kwargs or {}))
        except Exception:
            await tx.rollback()
            raise
        else:
            await tx.commit()
            return value