예제 #1
0
class Bolt4x3(Bolt4x2):
    """ Protocol handler for Bolt 4.3.

    This is supported by Neo4j version 4.3.
    """

    PROTOCOL_VERSION = Version(4, 3)

    def route(self, database=None, bookmarks=None):
        """ Fetch routing information with a ROUTE message (signature 0x66).

        :param database: database name placed in the ROUTE message as-is
            (``None`` is sent unchanged)
        :param bookmarks: iterable of bookmarks, or ``None`` for no bookmarks
        :return: a one-element list holding the ``"rt"`` entry of the
            SUCCESS metadata (``None`` if the server did not supply it)
        :raise BoltRoutingError: if the server reports that the routing
            procedure is missing, or any failure other than
            ``Neo.ClientError.Database.DatabaseNotFound``
        """

        def fail(md):
            from neo4j._exceptions import BoltRoutingError
            code = md.get("code")
            if code == "Neo.ClientError.Database.DatabaseNotFound":
                return  # surface this error to the user
            elif code == "Neo.ClientError.Procedure.ProcedureNotFound":
                raise BoltRoutingError("Server does not support routing", self.unresolved_address)
            else:
                raise BoltRoutingError("Routing support broken on server", self.unresolved_address)

        routing_context = self.routing_context or {}
        log.debug("[#%04X]  C: ROUTE %r %r", self.local_port, routing_context, database)
        metadata = {}
        # The message needs a concrete list, never None.
        bookmarks = [] if bookmarks is None else list(bookmarks)
        self._append(b"\x66", (routing_context, bookmarks, database),
                     response=Response(self, on_success=metadata.update, on_failure=fail))
        self.send_all()
        self.fetch_all()
        return [metadata.get("rt")]

    def hello(self):
        """ Send HELLO (base headers plus auth details) and process the reply.

        On SUCCESS, any ``hints`` are merged into ``configuration_hints``, and
        the remaining metadata is merged into ``server_info``. If the hint
        ``connection.recv_timeout_seconds`` is a positive integer, it becomes
        the socket timeout; otherwise an info message is logged. Credentials
        are masked in the debug log. Finally, the server product is checked.
        """
        def on_success(metadata):
            self.configuration_hints.update(metadata.pop("hints", {}))
            self.server_info.update(metadata)
            if "connection.recv_timeout_seconds" in self.configuration_hints:
                recv_timeout = self.configuration_hints[
                    "connection.recv_timeout_seconds"
                ]
                # bool is a subclass of int: reject it explicitly so that a
                # hint of ``true`` is not silently applied as a 1s timeout.
                if (isinstance(recv_timeout, int)
                        and not isinstance(recv_timeout, bool)
                        and recv_timeout > 0):
                    self.socket.settimeout(recv_timeout)
                else:
                    log.info("[#%04X]  Server supplied an invalid value for "
                             "connection.recv_timeout_seconds (%r). Make sure "
                             "the server and network is set up correctly.",
                             self.local_port, recv_timeout)

        headers = self.get_base_headers()
        headers.update(self.auth_dict)
        logged_headers = dict(headers)
        if "credentials" in logged_headers:
            logged_headers["credentials"] = "*******"
        log.debug("[#%04X]  C: HELLO %r", self.local_port, logged_headers)
        self._append(b"\x01", (headers,),
                     response=InitResponse(self, on_success=on_success))
        self.send_all()
        self.fetch_all()
        check_supported_server_product(self.server_info.agent)
예제 #2
0
class Bolt4x3(Bolt4x2):
    """ Bolt 4.3 protocol handler, supported by Neo4j 4.3.

    Routing tables are requested with a dedicated ROUTE message.
    """

    PROTOCOL_VERSION = Version(4, 3)

    def route(self, database=None, bookmarks=None):
        """ Ask the server for routing information for ``database``.

        :param database: database name sent in the ROUTE message (may be None)
        :param bookmarks: optional iterable of bookmarks
        :return: list with the single ``"rt"`` metadata value (or ``None``)
        :raise BoltRoutingError: when routing is unsupported or broken
        """

        def fail(md):
            from neo4j._exceptions import BoltRoutingError
            failure_code = md.get("code")
            if failure_code == "Neo.ClientError.Database.DatabaseNotFound":
                # Not a routing problem: leave this failure for the user to see.
                return
            if failure_code == "Neo.ClientError.Procedure.ProcedureNotFound":
                reason = "Server does not support routing"
            else:
                reason = "Routing support broken on server"
            raise BoltRoutingError(reason, self.unresolved_address)

        context = self.routing_context or {}
        log.debug("[#%04X]  C: ROUTE %r %r", self.local_port, context, database)
        routing_metadata = {}
        bookmark_list = list(bookmarks) if bookmarks is not None else []
        route_response = Response(self, on_success=routing_metadata.update,
                                  on_failure=fail)
        self._append(b"\x66", (context, bookmark_list, database),
                     response=route_response)
        self.send_all()
        self.fetch_all()
        return [routing_metadata.get("rt")]
예제 #3
0
class Bolt4x2(Bolt4x1):
    """ Bolt 4.2 protocol handler, supported by Neo4j 4.2.

    Only the protocol version differs from the Bolt 4.1 handler;
    all behaviour is inherited.
    """

    PROTOCOL_VERSION = Version(4, 2)
예제 #4
0
 def version_list(cls, versions, limit=4):
     """ Return a list of supported protocol versions in order of
     preference. The number of protocol versions (or ranges)
     returned is limited to ``limit`` (four by default).

     :param versions: list of ``Version`` objects, most preferred first
         (presumably sorted in descending order -- confirm with callers)
     :param limit: maximum number of entries in the returned list
     :return: if ``versions`` is non-empty and its first entry is at least
         4.3, a list starting with one range entry
         ``Version(major, [minor, ...])``. That entry covers the leading run
         of versions sharing the first major. It is followed by the rest of
         that run individually (presumably as fallbacks for servers that
         cannot parse ranges), truncated so the list has at most
         ``limit - 1`` entries. Last comes ``versions[end]``, where ``end``
         is the index of the first entry with a different major (or the
         last index if all entries share one major). Otherwise
         ``versions[:limit]`` is returned; an empty input yields an empty
         list.
     """
     # Check emptiness before indexing: versions[0] used to be evaluated
     # unconditionally, so an empty list raised IndexError.
     if not versions or not versions[0] >= Version(4, 3):
         return versions[:limit]

     first_major = versions[0][0]
     minors = []
     end = 0
     for end, version in enumerate(versions):
         if version[0] != first_major:
             break
         minors.append(version[1])
     new_versions = ([Version(first_major, minors)] +
                     versions[1:end])[:(limit - 1)]
     # ``versions`` is non-empty, so ``end`` is always a valid index here.
     new_versions.append(versions[end])
     return new_versions
예제 #5
0
    async def _handshake(cls, reader, writer, protocol_version):
        """ Carry out a Bolt handshake, optionally requesting a
        specific protocol version.

        Up to four versions (highest first) from
        ``cls.protocol_handlers(protocol_version)`` are offered after the
        magic preamble. The 4-byte server reply is then mapped back to a
        handler class.

        :param reader: stream the server reply is read from
        :param writer: stream the handshake request is written to
        :param protocol_version: version passed to
            ``cls.protocol_handlers`` to select candidate handlers
        :return: the handler class for the agreed protocol version
        :raise ValueError: if no protocol handlers are available for
            ``protocol_version``
        :raise BoltConnectionLost: if an I/O error occurs on the
            underlying socket connection (NOTE(review): nothing in this body
            raises it directly; ``reader.readexactly`` raises
            ``asyncio.IncompleteReadError`` on early EOF)
        :raise BoltHandshakeError: if handshake completes without a
            successful negotiation
        """
        local_address = Address(writer.transport.get_extra_info("sockname"))
        remote_address = Address(writer.transport.get_extra_info("peername"))

        handlers = cls.protocol_handlers(protocol_version)
        if not handlers:
            # ValueError does not interpolate its arguments, so format
            # eagerly. The 1-tuple keeps a tuple-valued version from being
            # spread across the % operator.
            raise ValueError(
                "No protocol handlers available (requested Bolt %r)"
                % (protocol_version,))
        offered_versions = sorted(handlers.keys(), reverse=True)[:4]

        request_data = MAGIC + b"".join(
            v.to_bytes() for v in offered_versions).ljust(16, b"\x00")
        log.debug("[#%04X] C: <HANDSHAKE> %r", local_address.port_number,
                  request_data)
        writer.write(request_data)
        await writer.drain()
        response_data = await reader.readexactly(4)
        log.debug("[#%04X] S: <HANDSHAKE> %r", local_address.port_number,
                  response_data)
        try:
            agreed_version = Version.from_bytes(response_data)
        except ValueError as err:
            writer.close()
            raise BoltHandshakeError(
                "Unexpected handshake response %r" % response_data,
                remote_address, request_data, response_data) from err
        try:
            subclass = handlers[agreed_version]
        except KeyError:
            log.debug("Unsupported Bolt protocol version %s", agreed_version)
            # Close here too, like the bad-response branch above, so a failed
            # negotiation does not leave the transport open.
            writer.close()
            raise BoltHandshakeError("Unsupported Bolt protocol version",
                                     remote_address, request_data,
                                     response_data)
        return subclass
예제 #6
0
class Bolt4x1(Bolt4x0):
    """ Bolt 4.1 protocol handler, supported by Neo4j 4.1 and 4.2.
    """

    PROTOCOL_VERSION = Version(4, 1)

    def get_base_headers(self):
        """ Build the base HELLO headers for Bolt 4.1.

        In addition to the user agent, this version sends the routing
        context (originally taken from the URI) during connection
        initialisation. This lets server-side routing propagate the same
        behaviour through its driver.
        """
        headers = dict(user_agent=self.user_agent)
        headers["routing"] = self.routing_context
        return headers
예제 #7
0
    def supports_multi_db(self):
        """ Check if the server or cluster supports multi-databases.

        A read connection is acquired from the pool (using the workspace's
        configured acquisition timeout and database). Its negotiated Bolt
        version and the server version are inspected, and the connection
        is always released afterwards.

        :return: Returns true if the server or cluster the driver connects to supports multi-databases, otherwise false.
        :rtype: bool
        """
        from neo4j.io._bolt4x0 import Bolt4x0

        cx = self._pool.acquire(
            access_mode=READ_ACCESS,
            timeout=self._pool.workspace_config.connection_acquisition_timeout,
            database=self._pool.workspace_config.database)
        try:
            # TODO: This logic should be inside the Bolt subclasses, because it can change depending on Bolt Protocol Version.
            # Both the negotiated protocol (Bolt 4.0+) and the server itself
            # (4.0.0+) must support multiple databases.
            return bool(
                cx.PROTOCOL_VERSION >= Bolt4x0.PROTOCOL_VERSION
                and cx.server_info.version_info() >= Version(4, 0, 0))
        finally:
            # Release even if a version check raises, so the pool does not
            # lose the connection.
            self._pool.release(cx)


def test_should_not_allow_empty_statements(session):
    # Running an empty query string must be rejected with ValueError.
    with raises(ValueError):
        session.run("")


def test_statement_object(session):
    # A Statement object can be run with keyword parameters bound to $x.
    record = session.run(Statement("RETURN $x"), x=1).single()
    assert record.value() == 1


@pytest.mark.parametrize("test_input, neo4j_version", [
    ("CALL dbms.getTXMetaData", Version(3, 0)),
    ("CALL tx.getMetaData", Version(4, 0)),
])
def test_autocommit_transactions_should_support_metadata(
        session, test_input, neo4j_version):
    # python -m pytest tests/integration/test_autocommit.py -s -r fEsxX -k test_autocommit_transactions_should_support_metadata
    metadata_in = {"foo": "bar"}

    result = session.run("RETURN 1")
    value = result.single().value()
    summary = result.summary()
    server_agent = summary.server.agent

    try:
        statement = Statement(test_input, metadata=metadata_in)
        result = session.run(statement)
예제 #9
0
class Bolt4x0(Bolt):
    """ Protocol handler for Bolt 4.0.

    This is supported by Neo4j versions 4.0, 4.1 and 4.2.
    """

    PROTOCOL_VERSION = Version(4, 0)

    supports_multiple_results = True

    supports_multiple_databases = True

    @property
    def encrypted(self):
        """ True if the underlying socket is an ``SSLSocket``. """
        return isinstance(self.socket, SSLSocket)

    @property
    def der_encoded_server_certificate(self):
        """ The peer certificate in binary (DER) form. """
        return self.socket.getpeercert(binary_form=True)

    @property
    def local_port(self):
        """ Local port of the socket, or 0 if ``getsockname`` fails. """
        try:
            return self.socket.getsockname()[1]
        except OSError:
            return 0

    def get_base_headers(self):
        """ Return the base HELLO headers for this protocol version. """
        return {
            "user_agent": self.user_agent,
        }

    def hello(self):
        """ Send HELLO (base headers plus auth details), wait for the reply
        and check the server product. Credentials are masked in the log.
        """
        headers = self.get_base_headers()
        headers.update(self.auth_dict)
        logged_headers = dict(headers)
        if "credentials" in logged_headers:
            logged_headers["credentials"] = "*******"
        log.debug("[#%04X]  C: HELLO %r", self.local_port, logged_headers)
        self._append(b"\x01", (headers,),
                     response=InitResponse(self, on_success=self.server_info.update))
        self.send_all()
        self.fetch_all()
        check_supported_server_product(self.server_info.agent)

    def route(self, database=None, bookmarks=None):
        """ Fetch routing information by calling the
        ``dbms.routing.getRoutingTable`` procedure against the system database.

        :param database: database to route for; ``None`` means the default
            database, in which case the procedure is called without a
            database argument
        :param bookmarks: bookmarks forwarded to RUN
        :return: list of dicts, one per returned record, mapping the
            ``"fields"`` names from the response metadata to record values
        :raise BoltRoutingError: if the procedure is missing, or on any
            failure other than ``Neo.ClientError.Database.DatabaseNotFound``
        """
        metadata = {}
        records = []

        def fail(md):
            from neo4j._exceptions import BoltRoutingError
            code = md.get("code")
            if code == "Neo.ClientError.Database.DatabaseNotFound":
                return  # surface this error to the user
            elif code == "Neo.ClientError.Procedure.ProcedureNotFound":
                raise BoltRoutingError("Server does not support routing", self.unresolved_address)
            else:
                raise BoltRoutingError("Routing support broken on server", self.unresolved_address)

        if database is None:  # default database
            query = "CALL dbms.routing.getRoutingTable($context)"
            parameters = {"context": self.routing_context}
        else:
            query = "CALL dbms.routing.getRoutingTable($context, $database)"
            parameters = {"context": self.routing_context, "database": database}
        self.run(
            query,
            parameters,
            mode="r",
            bookmarks=bookmarks,
            db=SYSTEM_DATABASE,
            on_success=metadata.update, on_failure=fail
        )
        self.pull(on_success=metadata.update, on_records=records.extend)
        self.send_all()
        self.fetch_all()
        return [dict(zip(metadata.get("fields", ()), values)) for values in records]

    @staticmethod
    def _build_tx_extra(mode, bookmarks, metadata, timeout, db):
        """ Build the ``extra`` dict shared by RUN and BEGIN.

        Keys are added only for truthy arguments (``mode`` only when it
        denotes read access), in the order mode, db, bookmarks,
        tx_metadata, tx_timeout. ``timeout`` is in seconds and is sent as
        whole milliseconds.

        :raise TypeError: if bookmarks are not iterable, metadata is not
            dict-coercible, or ``1000 * timeout`` cannot be passed to int()
        """
        # NOTE(review): a str timeout is not rejected here: 1000 * "5" is
        # string repetition, which int() then parses. Consider validating.
        extra = {}
        if mode in (READ_ACCESS, "r"):
            extra["mode"] = "r"  # It will default to mode "w" if nothing is specified
        if db:
            extra["db"] = db
        if bookmarks:
            try:
                extra["bookmarks"] = list(bookmarks)
            except TypeError:
                raise TypeError("Bookmarks must be provided within an iterable")
        if metadata:
            try:
                extra["tx_metadata"] = dict(metadata)
            except TypeError:
                raise TypeError("Metadata must be coercible to a dict")
        if timeout:
            try:
                extra["tx_timeout"] = int(1000 * timeout)
            except TypeError:
                raise TypeError("Timeout must be specified as a number of seconds")
        return extra

    def run(self, query, parameters=None, mode=None, bookmarks=None, metadata=None, timeout=None, db=None, **handlers):
        """ Queue a RUN message; it is written on the next ``send_all``.

        A query equal to ``COMMIT`` (case-insensitive) gets a
        ``CommitResponse``; any other query gets a plain ``Response``.

        :raise TypeError: see ``_build_tx_extra``
        """
        if not parameters:
            parameters = {}
        extra = self._build_tx_extra(mode, bookmarks, metadata, timeout, db)
        fields = (query, parameters, extra)
        log.debug("[#%04X]  C: RUN %s", self.local_port, " ".join(map(repr, fields)))
        if query.upper() == u"COMMIT":
            self._append(b"\x10", fields, CommitResponse(self, **handlers))
        else:
            self._append(b"\x10", fields, Response(self, **handlers))
        self._is_reset = False

    def discard(self, n=-1, qid=-1, **handlers):
        """ Queue a DISCARD for ``n`` records (qid sent only if not -1). """
        extra = {"n": n}
        if qid != -1:
            extra["qid"] = qid
        log.debug("[#%04X]  C: DISCARD %r", self.local_port, extra)
        self._append(b"\x2F", (extra,), Response(self, **handlers))

    def pull(self, n=-1, qid=-1, **handlers):
        """ Queue a PULL for ``n`` records (qid sent only if not -1). """
        extra = {"n": n}
        if qid != -1:
            extra["qid"] = qid
        log.debug("[#%04X]  C: PULL %r", self.local_port, extra)
        self._append(b"\x3F", (extra,), Response(self, **handlers))
        self._is_reset = False

    def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None,
              db=None, **handlers):
        """ Queue a BEGIN message.

        :raise TypeError: see ``_build_tx_extra``
        """
        extra = self._build_tx_extra(mode, bookmarks, metadata, timeout, db)
        log.debug("[#%04X]  C: BEGIN %r", self.local_port, extra)
        self._append(b"\x11", (extra,), Response(self, **handlers))
        self._is_reset = False

    def commit(self, **handlers):
        """ Queue a COMMIT message. """
        log.debug("[#%04X]  C: COMMIT", self.local_port)
        self._append(b"\x12", (), CommitResponse(self, **handlers))

    def rollback(self, **handlers):
        """ Queue a ROLLBACK message. """
        log.debug("[#%04X]  C: ROLLBACK", self.local_port)
        self._append(b"\x13", (), Response(self, **handlers))

    def reset(self):
        """ Add a RESET message to the outgoing queue, send
        it and consume all remaining messages.

        :raise BoltProtocolError: if the server answers RESET with FAILURE
        """

        def fail(metadata):
            raise BoltProtocolError("RESET failed %r" % metadata, self.unresolved_address)

        log.debug("[#%04X]  C: RESET", self.local_port)
        self._append(b"\x0F", response=Response(self, on_failure=fail))
        self.send_all()
        self.fetch_all()
        self._is_reset = True

    def fetch_message(self):
        """ Receive at most one message from the server, if available.

        :return: 2-tuple of number of detail messages and number of summary
                 messages fetched
        :raise ServiceUnavailable: if the connection is closed or defunct
        :raise BoltProtocolError: on an unknown summary signature
        """
        if self._closed:
            raise ServiceUnavailable("Failed to read from closed connection {!r} ({!r})".format(
                self.unresolved_address, self.server_info.address))

        if self._defunct:
            raise ServiceUnavailable("Failed to read from defunct connection {!r} ({!r})".format(
                self.unresolved_address, self.server_info.address))

        if not self.responses:
            return 0, 0

        # Receive exactly one message
        details, summary_signature, summary_metadata = next(self.inbox)

        if details:
            log.debug("[#%04X]  S: RECORD * %d", self.local_port, len(details))  # Do not log any data
            self.responses[0].on_records(details)

        if summary_signature is None:
            return len(details), 0

        response = self.responses.popleft()
        response.complete = True
        if summary_signature == b"\x70":
            log.debug("[#%04X]  S: SUCCESS %r", self.local_port, summary_metadata)
            response.on_success(summary_metadata or {})
        elif summary_signature == b"\x7E":
            log.debug("[#%04X]  S: IGNORED", self.local_port)
            response.on_ignored(summary_metadata or {})
        elif summary_signature == b"\x7F":
            log.debug("[#%04X]  S: FAILURE %r", self.local_port, summary_metadata)
            try:
                response.on_failure(summary_metadata or {})
            except (ServiceUnavailable, DatabaseUnavailable):
                if self.pool:
                    self.pool.deactivate(address=self.unresolved_address)
                raise
            except (NotALeader, ForbiddenOnReadOnlyDatabase):
                if self.pool:
                    self.pool.on_write_failure(address=self.unresolved_address)
                raise
            except Neo4jError as e:
                if self.pool and e.invalidates_all_connections():
                    self.pool.mark_all_stale()
                raise
        else:
            raise BoltProtocolError("Unexpected response message with signature "
                                    "%02X" % ord(summary_signature), self.unresolved_address)

        return len(details), 1

    def close(self):
        """ Close the connection.

        Unless the connection is defunct, a GOODBYE is sent first on a
        best-effort basis (send errors are ignored). The socket is then
        closed and the connection marked closed. Calling it again does nothing.
        """
        if not self._closed:
            if not self._defunct:
                log.debug("[#%04X]  C: GOODBYE", self.local_port)
                self._append(b"\x02", ())
                try:
                    self._send_all()
                except (OSError, BoltError, DriverError):
                    pass
            log.debug("[#%04X]  C: <CLOSE>", self.local_port)
            try:
                self.socket.close()
            except OSError:
                pass
            finally:
                self._closed = True

    def closed(self):
        """ True once ``close`` has run. """
        return self._closed

    def defunct(self):
        """ True if the connection has been marked defunct. """
        return self._defunct
예제 #10
0
class Bolt3(Bolt):
    """ Protocol handler for Bolt 3.

    Wraps a connected socket. Messages are queued with ``_append`` and
    only written to the socket by ``send_all``; replies are processed by
    ``fetch_message`` / ``fetch_all``.
    """

    PROTOCOL_VERSION = Version(3, 0)

    # Whether the connection is currently in use (not set in this class;
    # presumably managed by the owning pool -- confirm)
    in_use = False

    # Set once close() has run
    _closed = False

    # Set by _set_defunct() after an unexpected connection failure
    _defunct = False

    #: The pool of which this connection is a member
    pool = None

    def __init__(self, unresolved_address, sock, max_connection_lifetime, *, auth=None, user_agent=None):
        self.unresolved_address = unresolved_address
        self.socket = sock
        self.server_info = ServerInfo(Address(sock.getpeername()), Bolt3.PROTOCOL_VERSION)
        self.outbox = Outbox()
        self.inbox = Inbox(self.socket, on_error=self._set_defunct)
        self.packer = Packer(self.outbox)
        self.unpacker = Unpacker(self.inbox)
        self.responses = deque()
        self._max_connection_lifetime = max_connection_lifetime
        self._creation_timestamp = perf_counter()
        self.supports_multiple_results = False
        self.supports_multiple_databases = False
        self._is_reset = True

        # Determine the user agent
        if user_agent:
            self.user_agent = user_agent
        else:
            self.user_agent = get_user_agent()

        # Determine auth details
        if not auth:
            self.auth_dict = {}
        elif isinstance(auth, tuple) and 2 <= len(auth) <= 3:
            from neo4j import Auth
            self.auth_dict = vars(Auth("basic", *auth))
        else:
            try:
                self.auth_dict = vars(auth)
            except (KeyError, TypeError):
                raise AuthError("Cannot determine auth details from %r" % auth)

        # Check for missing password
        try:
            credentials = self.auth_dict["credentials"]
        except KeyError:
            pass
        else:
            if credentials is None:
                raise AuthError("Password cannot be None")

    @property
    def encrypted(self):
        """ True if the underlying socket is an ``SSLSocket``. """
        return isinstance(self.socket, SSLSocket)

    @property
    def der_encoded_server_certificate(self):
        """ The peer certificate in binary (DER) form. """
        return self.socket.getpeercert(binary_form=True)

    @property
    def local_port(self):
        """ Local port of the socket, or 0 if ``getsockname`` fails. """
        try:
            return self.socket.getsockname()[1]
        except IOError:
            return 0

    def hello(self):
        """ Send HELLO (user agent plus auth details) and wait for the reply.
        Credentials are masked in the debug log.
        """
        headers = {"user_agent": self.user_agent}
        headers.update(self.auth_dict)
        logged_headers = dict(headers)
        if "credentials" in logged_headers:
            logged_headers["credentials"] = "*******"
        log.debug("[#%04X]  C: HELLO %r", self.local_port, logged_headers)
        self._append(b"\x01", (headers,),
                     response=InitResponse(self, on_success=self.server_info.metadata.update))
        self.send_all()
        self.fetch_all()

    @staticmethod
    def _build_tx_extra(mode, bookmarks, metadata, timeout):
        """ Build the ``extra`` dict shared by RUN and BEGIN.

        Keys are added only for truthy arguments (``mode`` only when it
        denotes read access), in the order mode, bookmarks, tx_metadata,
        tx_timeout. ``timeout`` is in seconds and is sent as whole
        milliseconds.

        :raise TypeError: if bookmarks are not iterable, metadata is not
            dict-coercible, or ``1000 * timeout`` cannot be passed to int()
        """
        extra = {}
        if mode in (READ_ACCESS, "r"):
            extra["mode"] = "r"  # It will default to mode "w" if nothing is specified
        if bookmarks:
            try:
                extra["bookmarks"] = list(bookmarks)
            except TypeError:
                raise TypeError("Bookmarks must be provided within an iterable")
        if metadata:
            try:
                extra["tx_metadata"] = dict(metadata)
            except TypeError:
                raise TypeError("Metadata must be coercible to a dict")
        if timeout:
            try:
                extra["tx_timeout"] = int(1000 * timeout)
            except TypeError:
                raise TypeError("Timeout must be specified as a number of seconds")
        return extra

    def run(self, query, parameters=None, mode=None, bookmarks=None, metadata=None, timeout=None, db=None, **handlers):
        """ Queue a RUN message; it is written on the next ``send_all``.

        :raise ConfigurationError: if ``db`` is given (not supported by Bolt 3)
        :raise TypeError: see ``_build_tx_extra``
        """
        if db is not None:
            raise ConfigurationError("Database name parameter for selecting database is not supported in Bolt Protocol {!r}. Database name {!r}.".format(Bolt3.PROTOCOL_VERSION, db))
        if not parameters:
            parameters = {}
        extra = self._build_tx_extra(mode, bookmarks, metadata, timeout)
        fields = (query, parameters, extra)
        log.debug("[#%04X]  C: RUN %s", self.local_port, " ".join(map(repr, fields)))
        if query.upper() == u"COMMIT":
            self._append(b"\x10", fields, CommitResponse(self, **handlers))
        else:
            self._append(b"\x10", fields, Response(self, **handlers))
        self._is_reset = False

    def discard(self, n=-1, qid=-1, **handlers):
        """ Queue DISCARD_ALL. """
        # Just ignore n and qid, it is not supported in the Bolt 3 Protocol.
        log.debug("[#%04X]  C: DISCARD_ALL", self.local_port)
        self._append(b"\x2F", (), Response(self, **handlers))

    def pull(self, n=-1, qid=-1, **handlers):
        """ Queue PULL_ALL. """
        # Just ignore n and qid, it is not supported in the Bolt 3 Protocol.
        log.debug("[#%04X]  C: PULL_ALL", self.local_port)
        self._append(b"\x3F", (), Response(self, **handlers))
        self._is_reset = False

    def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, db=None, **handlers):
        """ Queue a BEGIN message.

        :raise ConfigurationError: if ``db`` is given (not supported by Bolt 3)
        :raise TypeError: see ``_build_tx_extra``
        """
        if db is not None:
            raise ConfigurationError("Database name parameter for selecting database is not supported in Bolt Protocol {!r}. Database name {!r}.".format(Bolt3.PROTOCOL_VERSION, db))
        extra = self._build_tx_extra(mode, bookmarks, metadata, timeout)
        log.debug("[#%04X]  C: BEGIN %r", self.local_port, extra)
        self._append(b"\x11", (extra,), Response(self, **handlers))
        self._is_reset = False

    def commit(self, **handlers):
        """ Queue a COMMIT message. """
        log.debug("[#%04X]  C: COMMIT", self.local_port)
        self._append(b"\x12", (), CommitResponse(self, **handlers))

    def rollback(self, **handlers):
        """ Queue a ROLLBACK message. """
        log.debug("[#%04X]  C: ROLLBACK", self.local_port)
        self._append(b"\x13", (), Response(self, **handlers))

    def _append(self, signature, fields=(), response=None):
        """ Add a message to the outgoing queue.

        :arg signature: the signature of the message
        :arg fields: the fields of the message as a tuple
        :arg response: a response object to handle callbacks
        """
        self.packer.pack_struct(signature, fields)
        # chunk() is called twice on purpose; presumably the second call emits
        # the empty chunk that terminates the message -- confirm against
        # Outbox before changing.
        self.outbox.chunk()
        self.outbox.chunk()
        self.responses.append(response)

    def reset(self):
        """ Add a RESET message to the outgoing queue, send
        it and consume all remaining messages.

        :raise BoltProtocolError: if the server answers RESET with FAILURE
        """

        def fail(metadata):
            raise BoltProtocolError("RESET failed %r" % metadata, address=self.unresolved_address)

        log.debug("[#%04X]  C: RESET", self.local_port)
        self._append(b"\x0F", response=Response(self, on_failure=fail))
        self.send_all()
        self.fetch_all()
        self._is_reset = True

    def _send_all(self):
        """ Write the outbox contents to the socket, then clear the outbox. """
        data = self.outbox.view()
        if data:
            self.socket.sendall(data)
            self.outbox.clear()

    def send_all(self):
        """ Send all queued messages to the server.

        :raise ServiceUnavailable: if the connection is closed or defunct
        :raise OSError: re-raised after logging (and deactivating this
            address in the pool) when the write fails
        """
        if self.closed():
            raise ServiceUnavailable("Failed to write to closed connection {!r} ({!r})".format(
                self.unresolved_address, self.server_info.address))

        if self.defunct():
            raise ServiceUnavailable("Failed to write to defunct connection {!r} ({!r})".format(
                self.unresolved_address, self.server_info.address))

        try:
            self._send_all()
        except (IOError, OSError) as error:
            log.error("Failed to write data to connection "
                      "{!r} ({!r}); ({!r})".
                      format(self.unresolved_address,
                             self.server_info.address,
                             "; ".join(map(repr, error.args))))
            if self.pool:
                self.pool.deactivate(address=self.unresolved_address)
            raise

    def fetch_message(self):
        """ Receive at most one message from the server, if available.

        :return: 2-tuple of number of detail messages and number of summary
                 messages fetched
        :raise ServiceUnavailable: if the connection is closed or defunct
        :raise BoltProtocolError: on an unknown summary signature
        """
        if self._closed:
            raise ServiceUnavailable("Failed to read from closed connection {!r} ({!r})".format(
                self.unresolved_address, self.server_info.address))

        if self._defunct:
            raise ServiceUnavailable("Failed to read from defunct connection {!r} ({!r})".format(
                self.unresolved_address, self.server_info.address))

        if not self.responses:
            return 0, 0

        # Receive exactly one message
        try:
            details, summary_signature, summary_metadata = next(self.inbox)
        except (IOError, OSError) as error:
            log.error("Failed to read data from connection "
                      "{!r} ({!r}); ({!r})".
                      format(self.unresolved_address,
                             self.server_info.address,
                             "; ".join(map(repr, error.args))))
            if self.pool:
                self.pool.deactivate(address=self.unresolved_address)
            raise

        if details:
            log.debug("[#%04X]  S: RECORD * %d", self.local_port, len(details))  # Do not log any data
            self.responses[0].on_records(details)

        if summary_signature is None:
            return len(details), 0

        response = self.responses.popleft()
        response.complete = True
        if summary_signature == b"\x70":
            log.debug("[#%04X]  S: SUCCESS %r", self.local_port, summary_metadata)
            response.on_success(summary_metadata or {})
        elif summary_signature == b"\x7E":
            log.debug("[#%04X]  S: IGNORED", self.local_port)
            response.on_ignored(summary_metadata or {})
        elif summary_signature == b"\x7F":
            log.debug("[#%04X]  S: FAILURE %r", self.local_port, summary_metadata)
            try:
                response.on_failure(summary_metadata or {})
            except (ServiceUnavailable, DatabaseUnavailable):
                if self.pool:
                    self.pool.deactivate(address=self.unresolved_address)
                raise
            except (NotALeader, ForbiddenOnReadOnlyDatabase):
                if self.pool:
                    self.pool.on_write_failure(address=self.unresolved_address)
                raise
        else:
            # The signature is a bytes value compared against b"\x70" etc.
            # above: %X needs an int, so convert with ord() (otherwise
            # formatting raises TypeError instead of BoltProtocolError).
            raise BoltProtocolError("Unexpected response message with signature %02X" % ord(summary_signature), address=self.unresolved_address)

        return len(details), 1

    def _set_defunct(self, error=None):
        """ Mark the connection defunct, close it and raise.

        :raise BoltIncompleteCommitError: if a COMMIT response is still pending
        :raise ServiceUnavailable: when the pool is a ``BoltPool``
        :raise SessionExpired: otherwise
        """
        direct_driver = isinstance(self.pool, BoltPool)

        message = ("Failed to read from defunct connection {!r} ({!r})".format(
            self.unresolved_address, self.server_info.address))

        log.error(message)
        # We were attempting to receive data but the connection
        # has unexpectedly terminated. So, we need to close the
        # connection from the client side, and remove the address
        # from the connection pool.
        self._defunct = True
        self.close()
        if self.pool:
            self.pool.deactivate(address=self.unresolved_address)
        # Iterate through the outstanding responses, and if any correspond
        # to COMMIT requests then raise an error to signal that we are
        # unable to confirm that the COMMIT completed successfully.
        for response in self.responses:
            if isinstance(response, CommitResponse):
                raise BoltIncompleteCommitError(message, address=None)

        if direct_driver:
            raise ServiceUnavailable(message)
        else:
            raise SessionExpired(message)

    def timedout(self):
        """ True if a non-negative max lifetime has elapsed since creation. """
        return 0 <= self._max_connection_lifetime <= perf_counter() - self._creation_timestamp

    def fetch_all(self):
        """ Fetch all outstanding messages.

        :return: 2-tuple of number of detail messages and number of summary
                 messages fetched
        """
        detail_count = summary_count = 0
        while self.responses:
            response = self.responses[0]
            while not response.complete:
                detail_delta, summary_delta = self.fetch_message()
                detail_count += detail_delta
                summary_count += summary_delta
        return detail_count, summary_count

    def close(self):
        """ Close the connection.

        Unless the connection is defunct, a GOODBYE is sent first on a
        best-effort basis. The socket is then closed and the connection
        marked closed. Calling it again does nothing.
        """
        if not self._closed:
            if not self._defunct:
                log.debug("[#%04X]  C: GOODBYE", self.local_port)
                self._append(b"\x02", ())
                try:
                    self._send_all()
                except Exception:
                    # Best effort: we are closing anyway. Unlike the former
                    # bare ``except:``, this no longer swallows
                    # KeyboardInterrupt/SystemExit.
                    pass
            log.debug("[#%04X]  C: <CLOSE>", self.local_port)
            try:
                self.socket.close()
            except IOError:
                pass
            finally:
                self._closed = True

    def closed(self):
        """ True once ``close`` has run. """
        return self._closed

    def defunct(self):
        """ True if the connection has been marked defunct. """
        return self._defunct
예제 #11
0
class Bolt3(Bolt):
    """ Asynchronous connection speaking Bolt protocol version 3.0.

    All messages go through a :class:`Courier` built around the given
    reader/writer pair. At most one transaction is tracked at a time in
    ``self._tx``. Failures passed to :meth:`fail` are handed to a callable
    registered for their exact type via :meth:`set_failure_handler`, or
    raised when no such handler exists.
    """

    protocol_version = Version(3, 0)

    # Set from the "server" entry of a successful HELLO summary.
    server_agent = None

    # Set from the "connection_id" entry of a successful HELLO summary.
    connection_id = None

    def __init__(self, reader, writer):
        # ``self.fail`` is handed to the courier, presumably so that it can
        # route failures through the handler table below (confirm in Courier).
        self._courier = Courier(reader, writer, self.fail)
        self._tx = None
        # Maps an exact failure type to the callable that handles it.
        self._failure_handlers = {}

    async def __ainit__(self, auth):
        """ Send HELLO and process the server's reply.

        :param auth: falsy for no authentication; otherwise items ``[0]``
            and ``[1]`` are sent as principal and credentials under the
            "basic" scheme

        If the summary is not successful, the connection is closed and a
        :class:`BoltFailure` is passed to :meth:`fail`. That raises it unless
        a handler is registered for that type.
        """
        args = {
            "scheme": "none",
            "user_agent": self.default_user_agent(),
        }
        if auth:
            args.update({
                "scheme": "basic",
                "principal": auth[0],  # TODO
                "credentials": auth[1],  # TODO
            })
        response = self._courier.write_hello(args)
        await self._courier.send()
        summary = await response.get_summary()
        if summary.success:
            self.server_agent = summary.metadata.get("server")
            self.connection_id = summary.metadata.get("connection_id")
            # TODO: verify genuine product
        else:
            await super().close()
            code = summary.metadata.get("code")
            message = summary.metadata.get("message")
            failure = BoltFailure(message, self.remote_address, code, response)
            self.fail(failure)

    async def close(self):
        """ Close the connection, sending GOODBYE first when not broken.

        A :class:`BoltConnectionBroken` raised while sending GOODBYE is
        ignored; the base-class close still runs. Closing an already-closed
        connection does nothing.
        """
        if self.closed:
            return
        if not self.broken:
            self._courier.write_goodbye()
            try:
                await self._courier.send()
            except BoltConnectionBroken:
                pass
        await super().close()

    @property
    def ready(self):
        """ If true, this flag indicates that there is no transaction
        in progress, and one may be started.
        """
        return not self._tx or self._tx.closed

    def _assert_open(self):
        """ Raise if the connection is closed or broken.

        :raise BoltConnectionClosed: if the connection has been closed
        :raise BoltConnectionBroken: if the connection is broken
        """
        if self.closed:
            raise BoltConnectionClosed("Connection has been closed",
                                       self.remote_address)
        if self.broken:
            raise BoltConnectionBroken("Connection is broken",
                                       self.remote_address)

    def _assert_ready(self):
        """ Raise unless the connection is open and has no live transaction.

        :raise BoltTransactionError: if a transaction is in progress
        """
        self._assert_open()
        if not self.ready:
            # TODO: add transaction identifier
            raise BoltTransactionError(
                "A transaction is already in progress on "
                "this connection", self.remote_address)

    async def reset(self, force=False):
        """ Queue RESET when forced or a transaction is in progress, then
        flush any pending requests and fetch any pending responses.

        :param force: queue RESET even if the connection is ready
        """
        self._assert_open()
        if force or not self.ready:
            self._courier.write_reset()
        if self._courier.requests_to_send:
            await self._courier.send()
        if self._courier.responses_to_fetch:
            await self._courier.fetch()

    async def run(self,
                  cypher,
                  parameters=None,
                  discard=False,
                  readonly=False,
                  bookmarks=None,
                  timeout=None,
                  metadata=None):
        """ Run *cypher* in a new :class:`Transaction` and return the
        result of its ``run`` call.

        :raise BoltTransactionError: if a transaction is already in progress
        """
        self._assert_ready()
        self._tx = Transaction(self._courier,
                               readonly=readonly,
                               bookmarks=bookmarks,
                               timeout=timeout,
                               metadata=metadata)
        return await self._tx.run(cypher, parameters, discard=discard)

    async def begin(self,
                    readonly=False,
                    bookmarks=None,
                    timeout=None,
                    metadata=None):
        """ Begin an explicit transaction, track it as current and return it.

        :raise BoltTransactionError: if a transaction is already in progress
        """
        self._assert_ready()
        self._tx = await Transaction.begin(self._courier,
                                           readonly=readonly,
                                           bookmarks=bookmarks,
                                           timeout=timeout,
                                           metadata=metadata)
        return self._tx

    async def run_tx(self,
                     f,
                     args=None,
                     kwargs=None,
                     readonly=False,
                     bookmarks=None,
                     timeout=None,
                     metadata=None):
        """ Run coroutine function *f* inside a new explicit transaction.

        ``f(tx, *args, **kwargs)`` is awaited. The transaction is committed
        and *f*'s value returned on success. If *f* raises an Exception,
        the transaction is rolled back and the exception re-raised.

        :param timeout: forwarded to :meth:`begin`
        :raise TypeError: if *f* is not a coroutine function; raised before
            any transaction is begun
        """
        self._assert_open()
        # Validate before BEGIN is sent, so a bad argument cannot leave an
        # open transaction behind (which would make the connection not ready).
        if not iscoroutinefunction(f):
            raise TypeError("Transaction function must be awaitable")
        tx = await self.begin(readonly=readonly,
                              bookmarks=bookmarks,
                              timeout=timeout,
                              metadata=metadata)
        try:
            value = await f(tx, *(args or ()), **(kwargs or {}))
        except Exception:
            await tx.rollback()
            raise
        else:
            await tx.commit()
            return value

    async def get_routing_table(self, context=None):
        """ Fetch and parse the cluster routing table.

        Calls ``dbms.cluster.routing.getRoutingTable`` with *context* (as a
        dict) and parses the single returned record with
        ``RoutingTable.parse_routing_info``.

        :raise BoltRoutingError: if no record is returned, the procedure is
            not found on the server, or the returned data fails to parse
            (ValueError)
        """
        try:
            result = await self.run(
                "CALL dbms.cluster.routing.getRoutingTable($context)",
                {"context": dict(context or {})})
            record = await result.single()
            if not record:
                raise BoltRoutingError(
                    "Routing table call returned "
                    "no data", self.remote_address)
            assert isinstance(record, Record)
            servers = record["servers"]
            ttl = record["ttl"]
            log.debug("[#%04X] S: <ROUTING> servers=%r ttl=%r",
                      self.local_address.port_number, servers, ttl)
            return RoutingTable.parse_routing_info(servers, ttl)
        except BoltFailure as error:
            if error.title == "ProcedureNotFound":
                raise BoltRoutingError("Server does not support "
                                       "routing",
                                       self.remote_address) from error
            else:
                raise
        except ValueError as error:
            raise BoltRoutingError("Invalid routing table",
                                   self.remote_address) from error

    def fail(self, failure):
        """ Dispatch *failure* to the handler registered for its exact type,
        or raise it if no callable handler is registered.
        """
        t = type(failure)
        handler = self.get_failure_handler(t)
        if callable(handler):
            # TODO: fix "requires two params, only one was given" error
            handler(failure)
        else:
            raise failure

    def get_failure_handler(self, cls):
        """ Return the handler registered for exactly *cls*, or None. """
        return self._failure_handlers.get(cls)

    def set_failure_handler(self, cls, f):
        """ Register *f* as the handler for failures of exactly type *cls*. """
        self._failure_handlers[cls] = f

    def del_failure_handler(self, cls):
        """ Remove the handler for *cls*; a missing entry is not an error. """
        self._failure_handlers.pop(cls, None)
예제 #12
0
class Bolt3(Bolt):

    protocol_version = Version(3, 0)

    server_agent = None

    connection_id = None

    def __init__(self, reader, writer):
        """ Wrap the reader/writer pair in a :class:`Courier`.

        No transaction is tracked until :meth:`run` or :meth:`begin`.
        """
        courier = Courier(reader, writer)
        self._courier = courier
        self._tx = None

    async def __ainit__(self, auth):
        """ Send HELLO and process the server's reply.

        :param auth: falsy for no authentication; otherwise items ``[0]``
            and ``[1]`` are sent as principal and credentials under the
            "basic" scheme
        :raise BoltFailure: if the summary is not successful; the connection
            is closed before raising
        """
        hello_args = {
            "scheme": "none",
            "user_agent": self.default_user_agent(),
        }
        if auth:
            # TODO: auth is indexed positionally; a richer auth type may follow
            principal, credentials = auth[0], auth[1]
            hello_args.update(scheme="basic",
                              principal=principal,
                              credentials=credentials)
        hello = self._courier.write_hello(hello_args)
        await self._courier.send()
        summary = await hello.get_summary()
        if not summary.success:
            await super().close()
            code = summary.metadata.get("code")
            message = summary.metadata.get("message")
            raise BoltFailure(message, self.remote_address, code, hello)
        self.server_agent = summary.metadata.get("server")
        self.connection_id = summary.metadata.get("connection_id")
        # TODO: verify genuine product

    async def close(self):
        """ Close the connection, sending GOODBYE first when not broken.

        A :class:`BoltConnectionBroken` raised while sending GOODBYE is
        ignored; the base-class close still runs. Closing an already-closed
        connection does nothing.
        """
        if self.closed:
            return
        say_goodbye = not self.broken
        if say_goodbye:
            self._courier.write_goodbye()
            try:
                await self._courier.send()
            except BoltConnectionBroken:
                # Ignored on purpose: the local close below must still happen.
                pass
        await super().close()

    @property
    def ready(self):
        """ True when no transaction is in progress, so one may be started.

        That is the case when no transaction has been tracked yet (``_tx``
        is falsy) or the tracked one reports itself closed.
        """
        current = self._tx
        if not current:
            return True
        return current.closed

    def _assert_open(self):
        """ Raise if the connection is closed or broken.

        :raise BoltConnectionClosed: if the connection has been closed
        :raise BoltConnectionBroken: if the connection is broken
        """
        if self.closed:
            error_type, text = BoltConnectionClosed, "Connection has been closed"
        elif self.broken:
            error_type, text = BoltConnectionBroken, "Connection is broken"
        else:
            return
        raise error_type(text, self.remote_address)

    def _assert_ready(self):
        """ Raise unless the connection is open and has no live transaction.

        :raise BoltTransactionError: if a transaction is in progress
        """
        self._assert_open()
        if self.ready:
            return
        # TODO: add transaction identifier
        raise BoltTransactionError(
            "A transaction is already in progress on "
            "this connection", self.remote_address)

    async def reset(self, force=False):
        """ Queue RESET when forced or a transaction is in progress, then
        flush any pending requests and fetch any pending responses.

        :param force: queue RESET even if the connection is ready
        """
        self._assert_open()
        courier = self._courier
        needs_reset = force or not self.ready
        if needs_reset:
            courier.write_reset()
        if courier.requests_to_send:
            await courier.send()
        if courier.responses_to_fetch:
            await courier.fetch()

    async def run(self,
                  cypher,
                  parameters=None,
                  discard=False,
                  readonly=False,
                  bookmarks=None,
                  timeout=None,
                  metadata=None):
        """ Run *cypher* in a new :class:`Transaction` tracked as current,
        returning the result of that transaction's ``run`` call.

        :raise BoltTransactionError: if a transaction is already in progress
        """
        self._assert_ready()
        tx = Transaction(self._courier,
                         readonly=readonly,
                         bookmarks=bookmarks,
                         timeout=timeout,
                         metadata=metadata)
        self._tx = tx
        return await tx.run(cypher, parameters, discard=discard)

    async def begin(self,
                    readonly=False,
                    bookmarks=None,
                    timeout=None,
                    metadata=None):
        """ Begin an explicit transaction, track it as current and return it.

        :raise BoltTransactionError: if a transaction is already in progress
        """
        self._assert_ready()
        new_tx = await Transaction.begin(self._courier,
                                         readonly=readonly,
                                         bookmarks=bookmarks,
                                         timeout=timeout,
                                         metadata=metadata)
        self._tx = new_tx
        return new_tx

    async def run_tx(self,
                     f,
                     args=None,
                     kwargs=None,
                     readonly=False,
                     bookmarks=None,
                     timeout=None,
                     metadata=None):
        """ Run coroutine function *f* inside a new explicit transaction.

        ``f(tx, *args, **kwargs)`` is awaited. The transaction is committed
        and *f*'s value returned on success. If *f* raises an Exception,
        the transaction is rolled back and the exception re-raised.

        :param timeout: forwarded to :meth:`begin` (previously dropped)
        :raise TypeError: if *f* is not a coroutine function; raised before
            any transaction is begun
        """
        self._assert_open()
        # Validate before BEGIN is sent, so a bad argument cannot leave an
        # open transaction behind (which would make the connection not ready).
        if not iscoroutinefunction(f):
            raise TypeError("Transaction function must be awaitable")
        tx = await self.begin(readonly=readonly,
                              bookmarks=bookmarks,
                              timeout=timeout,
                              metadata=metadata)
        try:
            value = await f(tx, *(args or ()), **(kwargs or {}))
        except Exception:
            await tx.rollback()
            raise
        else:
            await tx.commit()
            return value