コード例 #1
0
    def __init__(self, address, sock, protocol_version, error_handler, **config):
        """ Set up per-connection state over an already-connected socket.

        :param address: address this connection was opened against
        :param sock: the connected socket
        :param protocol_version: the protocol version in use
        :param error_handler: error handler stored for later use
        :param config: optional ``max_connection_lifetime``, ``user_agent``,
            ``auth`` and ``der_encoded_server_certificate`` settings; the
            first two fall back to ``default_config`` when absent
        :raises TypeError: if ``auth`` cannot be turned into a dict
        """
        self.address = address
        self.socket = sock
        self.protocol_version = protocol_version
        self.error_handler = error_handler
        self.server = ServerInfo(SocketAddress.from_socket(sock))
        self.input_buffer = ChunkedInputBuffer()
        self.output_buffer = ChunkedOutputBuffer()
        # The packer writes straight into the outgoing chunk buffer.
        self.packer = Packer(self.output_buffer)
        self.unpacker = Unpacker()
        self.responses = deque()
        lifetime_key = "max_connection_lifetime"
        self._max_connection_lifetime = config.get(lifetime_key, default_config[lifetime_key])
        self._creation_timestamp = perf_counter()

        # User agent: always stored as text, decoding UTF-8 bytes if needed.
        agent = config.get("user_agent", default_config["user_agent"])
        self.user_agent = agent.decode("UTF-8") if isinstance(agent, bytes) else agent

        # Auth: nothing -> empty dict; a 2- or 3-tuple -> basic auth token;
        # anything else -> the object's attribute dict.
        auth_token = config.get("auth")
        if not auth_token:
            self.auth_dict = {}
        elif isinstance(auth_token, tuple) and 2 <= len(auth_token) <= 3:
            from neo4j.v1 import basic_auth
            self.auth_dict = vars(basic_auth(*auth_token))
        else:
            try:
                self.auth_dict = vars(auth_token)
            except (KeyError, TypeError):
                raise TypeError("Cannot determine auth details from %r" % auth_token)

        # Server certificate supplied by the caller, or None.
        self.der_encoded_server_certificate = config.get("der_encoded_server_certificate")
コード例 #2
0
    def test_dehydration_2d(self):
        """ A 2D CartesianPoint packs as a 3-field 'X' struct: the 7203
        code (presumably the SRID) as INT_16, then each coordinate as a float.
        """
        coords = (.1, 0)
        point = CartesianPoint(coords)

        dehydrated = DataDehydrator().dehydrate((point, ))[0]
        out = io.BytesIO()
        Packer(out).pack(dehydrated)

        packed_coords = b"".join(b"\xC1" + struct.pack(">d", value) for value in coords)
        expected = b"\xB3X" + b"\xC9" + struct.pack(">h", 7203) + packed_coords
        self.assertEqual(out.getvalue(), expected)
コード例 #3
0
    def test_dehydration_3d(self):
        """ A 3D WGS84Point packs as a 4-field 'Y' struct: the 4979 code
        (presumably the SRID) as INT_16, then each coordinate as a float.
        """
        coords = (1, -2, 3.1)
        point = WGS84Point(coords)

        dehydrated = DataDehydrator().dehydrate((point, ))[0]
        out = io.BytesIO()
        Packer(out).pack(dehydrated)

        packed_coords = b"".join(b"\xC1" + struct.pack(">d", value) for value in coords)
        expected = b"\xB4Y" + b"\xC9" + struct.pack(">h", 4979) + packed_coords
        self.assertEqual(out.getvalue(), expected)
コード例 #4
0
    def test_dehydration(self):
        """ A custom 2D point type registered with code 1234 for two
        dimensions packs as a 3-field 'X' struct carrying that code.
        """
        MyPoint = point_type("MyPoint", ["x", "y"], {2: 1234})
        coords = (.1, 0)
        point = MyPoint(coords)

        dehydrated = DataDehydrator().dehydrate((point, ))[0]
        out = io.BytesIO()
        Packer(out).pack(dehydrated)

        packed_coords = b"".join(b"\xC1" + struct.pack(">d", value) for value in coords)
        expected = b"\xB3X" + b"\xC9" + struct.pack(">h", 1234) + packed_coords
        self.assertEqual(out.getvalue(), expected)
コード例 #5
0
    def __init__(self,
                 unresolved_address,
                 sock,
                 max_connection_lifetime,
                 *,
                 auth=None,
                 user_agent=None,
                 routing_context=None):
        """ Set up connection state over a socket that has completed the
        Bolt handshake.

        :param unresolved_address: address as originally requested
        :param sock: the connected socket (its peer name becomes the
                     server address)
        :param max_connection_lifetime: lifetime limit stored for
                                        timeout checks
        :param auth: None/falsy, a (user, password[, realm]) tuple, or an
                     object whose attributes form the auth dict
        :param user_agent: user agent string; ``get_user_agent()`` is used
                           when falsy
        :param routing_context: routing context stored on the connection
        :raises AuthError: if auth cannot be interpreted, or its
                           ``credentials`` entry is None
        """
        self.unresolved_address = unresolved_address
        self.socket = sock
        peer = Address(sock.getpeername())
        self.server_info = ServerInfo(peer, self.PROTOCOL_VERSION)
        self.outbox = Outbox()
        # Read failures on the inbox are routed to _set_defunct.
        self.inbox = Inbox(self.socket, on_error=self._set_defunct)
        self.packer = Packer(self.outbox)
        self.unpacker = Unpacker(self.inbox)
        self.responses = deque()
        self._max_connection_lifetime = max_connection_lifetime
        self._creation_timestamp = perf_counter()
        self.supports_multiple_results = False
        self.supports_multiple_databases = False
        self._is_reset = True
        self.routing_context = routing_context

        self.user_agent = user_agent or get_user_agent()

        # Normalise the supplied auth into a plain attribute dict.
        if not auth:
            auth_dict = {}
        elif isinstance(auth, tuple) and 2 <= len(auth) <= 3:
            from neo4j import Auth
            auth_dict = vars(Auth("basic", *auth))
        else:
            try:
                auth_dict = vars(auth)
            except (KeyError, TypeError):
                raise AuthError("Cannot determine auth details from %r" % auth)
        self.auth_dict = auth_dict

        # A credentials key explicitly set to None means no password given.
        if "credentials" in auth_dict and auth_dict["credentials"] is None:
            raise AuthError("Password cannot be None")
コード例 #6
0
 def assert_packable(cls, value, packed_value):
     """Check that *value* packs to *packed_value* and unpacks back to *value*.

     :param value: the value to pack
     :param packed_value: the exact bytes expected from packing *value*
     :raises AssertionError: if the packed bytes differ from *packed_value*,
         or if unpacking them does not give back a value equal to *value*

     Explicit ``raise`` is used rather than ``assert`` so the checks still
     run under ``python -O``, which strips assert statements.
     """
     stream_out = BytesIO()
     packer = Packer(stream_out)
     packer.pack(value)
     packed = stream_out.getvalue()
     if packed != packed_value:
         raise AssertionError("Packed value %r is %r instead of expected %r" %
                              (value, packed, packed_value))
     unpacked = Unpacker(UnpackableBuffer(packed)).unpack()
     if unpacked != value:
         raise AssertionError("Unpacked value %r is not equal to original %r" % (unpacked, value))
コード例 #7
0
    def __init__(self,
                 unresolved_address,
                 sock,
                 *,
                 auth=None,
                 protocol_version=None,
                 **config):
        """ Set up connection state over a socket that has completed the
        Bolt handshake.

        :param unresolved_address: address as originally requested
        :param sock: the connected socket (its peer name becomes the
                     server address)
        :param auth: None/falsy, a (user, password[, realm]) tuple, or an
                     object whose attributes form the auth dict
        :param protocol_version: negotiated protocol version
        :param config: settings consumed into a ``PoolConfig``; ``max_age``
                       and ``user_agent`` are read from it
        :raises AuthError: if auth cannot be interpreted, or its
                           ``credentials`` entry is None
        """
        self.config = PoolConfig.consume(config)
        self.protocol_version = protocol_version
        self.unresolved_address = unresolved_address
        self.socket = sock
        peer = Address(sock.getpeername())
        self.server = ServerInfo(peer, protocol_version)
        self.outbox = Outbox()
        # Reads go through a 32768-byte buffered wrapper; failures are
        # routed to _set_defunct.
        buffered = BufferedSocket(self.socket, 32768)
        self.inbox = Inbox(buffered, on_error=self._set_defunct)
        self.packer = Packer(self.outbox)
        self.unpacker = Unpacker(self.inbox)
        self.responses = deque()
        self._max_connection_lifetime = self.config.max_age
        self._creation_timestamp = perf_counter()

        self.user_agent = self.config.user_agent or get_user_agent()

        # Normalise the supplied auth into a plain attribute dict.
        if not auth:
            auth_dict = {}
        elif isinstance(auth, tuple) and 2 <= len(auth) <= 3:
            from neo4j import Auth
            auth_dict = vars(Auth("basic", *auth))
        else:
            try:
                auth_dict = vars(auth)
            except (KeyError, TypeError):
                raise AuthError("Cannot determine auth details from %r" % auth)
        self.auth_dict = auth_dict

        # A credentials key explicitly set to None means no password given.
        if "credentials" in auth_dict and auth_dict["credentials"] is None:
            raise AuthError("Password cannot be None")
コード例 #8
0
 def test_list_stream(self):
     """Stream-pack the list [1, 2, 3] and check both the bytes and the round trip.

     Mismatches are reported with an explicit ``raise AssertionError`` so the
     checks still run under ``python -O``, which strips assert statements.
     """
     packed_value = b"\xD7\x01\x02\x03\xDF"
     unpacked_value = [1, 2, 3]
     stream_out = BytesIO()
     packer = Packer(stream_out)
     packer.pack_list_stream_header()
     for item in unpacked_value:
         packer.pack(item)
     packer.pack_end_of_stream()
     packed = stream_out.getvalue()
     if packed != packed_value:
         raise AssertionError("Packed value is %r instead of expected %r" %
                              (packed, packed_value))
     unpacked = Unpacker(UnpackableBuffer(packed)).unpack()
     if unpacked != unpacked_value:
         raise AssertionError("Unpacked value %r is not equal to expected %r" %
                              (unpacked, unpacked_value))
コード例 #9
0
    def __init__(self, sock, **config):
        """ Set up connection state over *sock* and perform the INIT exchange.

        Unlike a plain state-only constructor, this one queues an INIT
        message carrying the user agent and auth dict, and calls ``sync()``
        before returning. It then records the server's statement-reuse and
        bytes support.

        :param sock: the connected socket
        :param config: optional ``user_agent`` (defaults to
            ``DEFAULT_USER_AGENT``), ``auth`` and
            ``der_encoded_server_certificate`` settings
        :raises TypeError: if ``auth`` cannot be turned into a dict
        """
        self.socket = sock
        self.server = ServerInfo(SocketAddress.from_socket(sock))
        self.input_buffer = ChunkedInputBuffer()
        self.output_buffer = ChunkedOutputBuffer()
        self.packer = Packer(self.output_buffer)
        self.unpacker = Unpacker()
        self.responses = deque()

        # User agent: always stored as text, decoding UTF-8 bytes if needed.
        agent = config.get("user_agent", DEFAULT_USER_AGENT)
        self.user_agent = agent.decode("UTF-8") if isinstance(agent, bytes) else agent

        # Auth: nothing -> empty dict; a 2- or 3-tuple -> basic auth token;
        # anything else -> the object's attribute dict.
        auth_token = config.get("auth")
        if not auth_token:
            self.auth_dict = {}
        elif isinstance(auth_token, tuple) and 2 <= len(auth_token) <= 3:
            from neo4j.v1 import basic_auth
            self.auth_dict = vars(basic_auth(*auth_token))
        else:
            try:
                self.auth_dict = vars(auth_token)
            except (KeyError, TypeError):
                raise TypeError("Cannot determine auth details from %r" % auth_token)

        # Server certificate supplied by the caller, or None.
        self.der_encoded_server_certificate = config.get("der_encoded_server_certificate")

        # Handshake: send INIT and block until its response completes.
        self.append(INIT, (self.user_agent, self.auth_dict), response=InitResponse(self))
        self.sync()

        server = self.server
        self._supports_statement_reuse = server.supports_statement_reuse()
        self.packer.supports_bytes = server.supports_bytes()
コード例 #10
0
class Connection(object):
    """ Server connection for Bolt protocol v1.

    A :class:`.Connection` should be constructed following a
    successful Bolt handshake and takes the socket over which
    the handshake was carried out.

    Outgoing messages are queued with :meth:`append` and flushed by
    :meth:`send`. Replies are read by :meth:`fetch` and dispatched to the
    response object that was queued alongside each message. :meth:`sync`
    sends, then fetches until every queued response is complete.

    .. note:: logs at INFO level

    NOTE(review): the logging calls in this class all go through
    ``log_debug``; confirm that the INFO-level note above is still accurate.
    """

    #: The protocol version in use on this connection
    protocol_version = 0

    #: Server details for this connection
    server = None

    # NOTE(review): not read or written anywhere in this class; presumably
    # set by the owning pool while the connection is checked out — confirm.
    in_use = False

    # Set to True by close(); reported by closed().
    _closed = False

    # Set to True by _receive() after a failed read; reported by defunct().
    _defunct = False

    #: The pool of which this connection is a member
    pool = None

    #: Error class used for raising connection errors
    Error = ServiceUnavailable

    # Enabled by init() from the server's capabilities. When set, append()
    # sends an empty statement in place of a RUN statement identical to
    # the previous one.
    _supports_statement_reuse = False

    # Text of the last RUN statement recorded for reuse; reset to None by
    # _fetch() on IGNORED, FAILURE or an unexpected summary signature.
    _last_run_statement = None

    def __init__(self, address, sock, protocol_version, error_handler,
                 **config):
        """ Store connection state; no messages are sent until init().

        :param address: address passed to ``error_handler.handle`` on errors
        :param sock: the connected socket
        :param protocol_version: the protocol version in use
        :param error_handler: object whose ``known_errors`` are intercepted
                              and passed to its ``handle`` method by send()
                              and fetch()
        :param config: optional ``max_connection_lifetime``, ``user_agent``,
                       ``auth`` and ``der_encoded_server_certificate``
        :raises TypeError: if ``auth`` cannot be turned into a dict
        """
        self.address = address
        self.socket = sock
        self.protocol_version = protocol_version
        self.error_handler = error_handler
        self.server = ServerInfo(SocketAddress.from_socket(sock))
        self.input_buffer = ChunkedInputBuffer()
        self.output_buffer = ChunkedOutputBuffer()
        self.packer = Packer(self.output_buffer)
        self.unpacker = Unpacker()
        self.responses = deque()
        self._max_connection_lifetime = config.get(
            "max_connection_lifetime",
            default_config["max_connection_lifetime"])
        self._creation_timestamp = perf_counter()

        # Determine the user agent and ensure it is a Unicode value
        user_agent = config.get("user_agent", default_config["user_agent"])
        if isinstance(user_agent, bytes):
            user_agent = user_agent.decode("UTF-8")
        self.user_agent = user_agent

        # Determine auth details: nothing -> {}, a 2- or 3-tuple -> basic
        # auth token, anything else -> the object's attribute dict.
        auth = config.get("auth")
        if not auth:
            self.auth_dict = {}
        elif isinstance(auth, tuple) and 2 <= len(auth) <= 3:
            from neo4j.v1 import basic_auth
            self.auth_dict = vars(basic_auth(*auth))
        else:
            try:
                self.auth_dict = vars(auth)
            except (KeyError, TypeError):
                raise TypeError("Cannot determine auth details from %r" % auth)

        # Pick up the server certificate, if any
        self.der_encoded_server_certificate = config.get(
            "der_encoded_server_certificate")

    def init(self):
        """ Send INIT (user agent + auth dict), wait for its response, then
        record whether the server supports statement reuse and bytes.
        """
        response = InitResponse(self)
        self.append(INIT, (self.user_agent, self.auth_dict), response=response)
        self.sync()

        self._supports_statement_reuse = self.server.supports_statement_reuse()
        self.packer.supports_bytes = self.server.supports_bytes()

    def __del__(self):
        # Best-effort close on garbage collection. AttributeError/TypeError
        # are swallowed — presumably for partially-initialised instances or
        # interpreter shutdown; confirm.
        try:
            self.close()
        except (AttributeError, TypeError):
            pass

    def append(self, signature, fields=(), response=None):
        """ Add a message to the outgoing queue.

        :arg signature: the signature of the message
        :arg fields: the fields of the message as a tuple
        :arg response: a response object to handle callbacks
        :raises ValueError: if *signature* is not one of RUN, PULL_ALL,
                            DISCARD_ALL, RESET, ACK_FAILURE or INIT
        """
        if signature == RUN:
            if self._supports_statement_reuse:
                statement = fields[0]
                # Transaction-control statements are always sent verbatim.
                if statement.upper() not in ("BEGIN", "COMMIT", "ROLLBACK"):
                    if statement == self._last_run_statement:
                        # Repeat of the previous statement: send "" instead.
                        fields = ("", ) + fields[1:]
                    else:
                        self._last_run_statement = statement
            log_debug("C: RUN %r", fields)
        elif signature == PULL_ALL:
            log_debug("C: PULL_ALL %r", fields)
        elif signature == DISCARD_ALL:
            log_debug("C: DISCARD_ALL %r", fields)
        elif signature == RESET:
            log_debug("C: RESET %r", fields)
        elif signature == ACK_FAILURE:
            log_debug("C: ACK_FAILURE %r", fields)
        elif signature == INIT:
            # Only the user agent is logged; the auth dict is never printed.
            log_debug("C: INIT (%r, {...})", fields[0])
        else:
            raise ValueError("Unknown message signature")
        self.packer.pack_struct(signature, fields)
        # NOTE(review): chunk() is called twice on purpose — presumably the
        # first call closes the message's data chunk and the second emits
        # the empty chunk that marks end-of-message; confirm against
        # ChunkedOutputBuffer.
        self.output_buffer.chunk()
        self.output_buffer.chunk()
        self.responses.append(response)

    def acknowledge_failure(self):
        """ Add an ACK_FAILURE message to the outgoing queue, send
        it and consume all remaining messages.
        """
        self.append(ACK_FAILURE, response=AckFailureResponse(self))
        self.sync()

    def reset(self):
        """ Add a RESET message to the outgoing queue, send
        it and consume all remaining messages.
        """
        self.append(RESET, response=ResetResponse(self))
        self.sync()

    def send(self):
        """ Send all queued messages. Errors listed in
        ``error_handler.known_errors`` are passed to
        ``error_handler.handle`` with this connection's address before
        being re-raised.
        """
        try:
            self._send()
        except self.error_handler.known_errors as error:
            self.error_handler.handle(error, self.address)
            raise error

    def _send(self):
        """ Send all queued messages to the server.

        Does nothing when the output buffer is empty; raises ``self.Error``
        if the connection is closed or defunct.
        """
        data = self.output_buffer.view()
        if not data:
            return
        if self.closed():
            raise self.Error(
                "Failed to write to closed connection {!r}".format(
                    self.server.address))
        if self.defunct():
            raise self.Error(
                "Failed to write to defunct connection {!r}".format(
                    self.server.address))
        self.socket.sendall(data)
        self.output_buffer.clear()

    def fetch(self):
        """ Fetch messages via _fetch(), routing ``known_errors`` through
        the error handler before re-raising them.

        :return: the 2-tuple returned by :meth:`_fetch`
        """
        try:
            return self._fetch()
        except self.error_handler.known_errors as error:
            self.error_handler.handle(error, self.address)
            raise error

    def _fetch(self):
        """ Receive at least one message from the server, if available.

        :return: 2-tuple of number of detail messages and number of summary messages fetched
        :raises: ``self.Error`` if the connection is closed or defunct;
                 ``ProtocolError`` on an unexpected summary signature
        """
        if self.closed():
            raise self.Error(
                "Failed to read from closed connection {!r}".format(
                    self.server.address))
        if self.defunct():
            raise self.Error(
                "Failed to read from defunct connection {!r}".format(
                    self.server.address))
        if not self.responses:
            return 0, 0

        self._receive()

        details, summary_signature, summary_metadata = self._unpack()

        # Records always belong to the oldest outstanding response.
        if details:
            log_debug("S: RECORD * %d", len(details))  # TODO
            self.responses[0].on_records(details)

        if summary_signature is None:
            return len(details), 0

        # A summary completes (and dequeues) the oldest response. Any
        # non-SUCCESS summary also forgets the reusable RUN statement.
        response = self.responses.popleft()
        response.complete = True
        if summary_signature == SUCCESS:
            log_debug("S: SUCCESS (%r)", summary_metadata)
            response.on_success(summary_metadata or {})
        elif summary_signature == IGNORED:
            self._last_run_statement = None
            log_debug("S: IGNORED (%r)", summary_metadata)
            response.on_ignored(summary_metadata or {})
        elif summary_signature == FAILURE:
            self._last_run_statement = None
            log_debug("S: FAILURE (%r)", summary_metadata)
            response.on_failure(summary_metadata or {})
        else:
            self._last_run_statement = None
            raise ProtocolError(
                "Unexpected response message with signature %02X" %
                summary_signature)

        return len(details), 1

    def _receive(self):
        """ Read one message into the input buffer.

        A socket error or a falsy result from ``receive_message`` marks the
        connection defunct, closes it and raises ``self.Error``.
        """
        try:
            received = self.input_buffer.receive_message(self.socket, 8192)
        except SocketError:
            received = False
        if not received:
            self._defunct = True
            self.close()
            raise self.Error(
                "Failed to read from defunct connection {!r}".format(
                    self.server.address))

    def _unpack(self):
        """ Unpack buffered messages: zero or more RECORDs, then at most one
        summary message.

        Unpacking stops at the first non-RECORD message, or when
        ``input_buffer.frame_message()`` reports no further message.

        :return: (list of record field lists, summary signature or None,
                  summary metadata or None)
        :raises ProtocolError: if a structure header declares more than one
                               field
        """
        unpacker = self.unpacker
        input_buffer = self.input_buffer

        details = []
        summary_signature = None
        summary_metadata = None
        more = True
        while more:
            unpacker.attach(input_buffer.frame())
            size, signature = unpacker.unpack_structure_header()
            if size > 1:
                raise ProtocolError("Expected one field")
            if signature == RECORD:
                data = unpacker.unpack_list()
                details.append(data)
                more = input_buffer.frame_message()
            else:
                summary_signature = signature
                summary_metadata = unpacker.unpack_map()
                more = False
        return details, summary_signature, summary_metadata

    def timedout(self):
        """ Return True once the connection's age reaches a non-negative
        ``_max_connection_lifetime``; a negative lifetime never times out.
        """
        return 0 <= self._max_connection_lifetime <= perf_counter(
        ) - self._creation_timestamp

    def sync(self):
        """ Send and fetch all outstanding messages.

        :return: 2-tuple of number of detail messages and number of summary messages fetched
        """
        self.send()
        detail_count = summary_count = 0
        while self.responses:
            response = self.responses[0]
            while not response.complete:
                detail_delta, summary_delta = self.fetch()
                detail_count += detail_delta
                summary_count += summary_delta
        return detail_count, summary_count

    def close(self):
        """ Close the connection.

        Closing an already-closed connection does nothing.
        """
        if not self.closed():
            log_debug("~~ [CLOSE]")
            self.socket.close()
            self._closed = True

    def closed(self):
        """ Return True once close() has completed. """
        return self._closed

    def defunct(self):
        """ Return True once a read failure has marked the connection defunct. """
        return self._defunct
コード例 #11
0
class Bolt3(Bolt):
    """ Server connection for Bolt protocol 3.0.

    Wraps a socket on which the Bolt handshake has already completed and
    exchanges Bolt 3 messages (HELLO, RUN, DISCARD_ALL, PULL_ALL, BEGIN,
    COMMIT, ROLLBACK, RESET, GOODBYE) over it.
    """

    PROTOCOL_VERSION = Version(3, 0)

    # NOTE(review): not set in this class; presumably toggled by the owning
    # pool while the connection is checked out — confirm.
    in_use = False

    # Set to True once close() has run.
    _closed = False

    # Set to True by _set_defunct() after an unexpected read failure.
    _defunct = False

    #: The pool of which this connection is a member
    pool = None

    def __init__(self, unresolved_address, sock, max_connection_lifetime, *, auth=None, user_agent=None):
        """ Set up connection state over a handshaken socket.

        :param unresolved_address: address as originally requested
        :param sock: the connected socket
        :param max_connection_lifetime: lifetime limit used by timedout()
        :param auth: None/falsy, a (user, password[, realm]) tuple, or an
                     object whose attributes form the auth dict
        :param user_agent: user agent; ``get_user_agent()`` when falsy
        :raises AuthError: if auth cannot be interpreted, or its
                           ``credentials`` entry is None
        """
        self.unresolved_address = unresolved_address
        self.socket = sock
        self.server_info = ServerInfo(Address(sock.getpeername()), Bolt3.PROTOCOL_VERSION)
        self.outbox = Outbox()
        self.inbox = Inbox(self.socket, on_error=self._set_defunct)
        self.packer = Packer(self.outbox)
        self.unpacker = Unpacker(self.inbox)
        self.responses = deque()
        self._max_connection_lifetime = max_connection_lifetime
        self._creation_timestamp = perf_counter()
        self.supports_multiple_results = False
        self.supports_multiple_databases = False
        self._is_reset = True

        # Determine the user agent
        if user_agent:
            self.user_agent = user_agent
        else:
            self.user_agent = get_user_agent()

        # Determine auth details
        if not auth:
            self.auth_dict = {}
        elif isinstance(auth, tuple) and 2 <= len(auth) <= 3:
            from neo4j import Auth
            self.auth_dict = vars(Auth("basic", *auth))
        else:
            try:
                self.auth_dict = vars(auth)
            except (KeyError, TypeError):
                raise AuthError("Cannot determine auth details from %r" % auth)

        # Check for missing password
        try:
            credentials = self.auth_dict["credentials"]
        except KeyError:
            pass
        else:
            if credentials is None:
                raise AuthError("Password cannot be None")

    @property
    def encrypted(self):
        """ True when the underlying socket is an SSLSocket. """
        return isinstance(self.socket, SSLSocket)

    @property
    def der_encoded_server_certificate(self):
        """ The peer certificate in binary form, as returned by getpeercert(). """
        return self.socket.getpeercert(binary_form=True)

    @property
    def local_port(self):
        """ Local port of the socket, or 0 if it cannot be determined. """
        try:
            return self.socket.getsockname()[1]
        except IOError:
            return 0

    def hello(self):
        """ Send HELLO (user agent + auth) and wait for the reply.

        Credentials are masked in the log output. On success the server's
        metadata is merged into ``server_info.metadata``.
        """
        headers = {"user_agent": self.user_agent}
        headers.update(self.auth_dict)
        logged_headers = dict(headers)
        if "credentials" in logged_headers:
            logged_headers["credentials"] = "*******"
        log.debug("[#%04X]  C: HELLO %r", self.local_port, logged_headers)
        self._append(b"\x01", (headers,),
                     response=InitResponse(self, on_success=self.server_info.metadata.update))
        self.send_all()
        self.fetch_all()

    @staticmethod
    def _transaction_extra(mode, bookmarks, metadata, timeout):
        """ Build the ``extra`` dict shared by RUN and BEGIN messages.

        :raises TypeError: if bookmarks are not iterable, metadata cannot be
                           converted to a dict, or computing
                           ``int(1000 * timeout)`` raises TypeError
        """
        # NOTE(review): a str timeout is not rejected here — "1000 * timeout"
        # repeats the string, so int() either raises ValueError or parses a
        # huge number. Consider validating the type explicitly.
        extra = {}
        if mode in (READ_ACCESS, "r"):
            extra["mode"] = "r"  # It will default to mode "w" if nothing is specified
        if bookmarks:
            try:
                extra["bookmarks"] = list(bookmarks)
            except TypeError:
                raise TypeError("Bookmarks must be provided within an iterable")
        if metadata:
            try:
                extra["tx_metadata"] = dict(metadata)
            except TypeError:
                raise TypeError("Metadata must be coercible to a dict")
        if timeout:
            try:
                extra["tx_timeout"] = int(1000 * timeout)
            except TypeError:
                raise TypeError("Timeout must be specified as a number of seconds")
        return extra

    def run(self, query, parameters=None, mode=None, bookmarks=None, metadata=None, timeout=None, db=None, **handlers):
        """ Queue a RUN message (not sent until send_all()).

        A query whose upper-cased text is "COMMIT" is tracked with a
        CommitResponse.

        :raises ConfigurationError: if *db* is given (unsupported in Bolt 3)
        :raises TypeError: on invalid bookmarks, metadata or timeout
        """
        if db is not None:
            raise ConfigurationError("Database name parameter for selecting database is not supported in Bolt Protocol {!r}. Database name {!r}.".format(Bolt3.PROTOCOL_VERSION, db))
        if not parameters:
            parameters = {}
        extra = self._transaction_extra(mode, bookmarks, metadata, timeout)
        fields = (query, parameters, extra)
        log.debug("[#%04X]  C: RUN %s", self.local_port, " ".join(map(repr, fields)))
        if query.upper() == u"COMMIT":
            self._append(b"\x10", fields, CommitResponse(self, **handlers))
        else:
            self._append(b"\x10", fields, Response(self, **handlers))
        self._is_reset = False

    def discard(self, n=-1, qid=-1, **handlers):
        """ Queue DISCARD_ALL. """
        # Just ignore n and qid, it is not supported in the Bolt 3 Protocol.
        log.debug("[#%04X]  C: DISCARD_ALL", self.local_port)
        self._append(b"\x2F", (), Response(self, **handlers))

    def pull(self, n=-1, qid=-1, **handlers):
        """ Queue PULL_ALL. """
        # Just ignore n and qid, it is not supported in the Bolt 3 Protocol.
        log.debug("[#%04X]  C: PULL_ALL", self.local_port)
        self._append(b"\x3F", (), Response(self, **handlers))
        self._is_reset = False

    def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, db=None, **handlers):
        """ Queue a BEGIN message.

        :raises ConfigurationError: if *db* is given (unsupported in Bolt 3)
        :raises TypeError: on invalid bookmarks, metadata or timeout
        """
        if db is not None:
            raise ConfigurationError("Database name parameter for selecting database is not supported in Bolt Protocol {!r}. Database name {!r}.".format(Bolt3.PROTOCOL_VERSION, db))
        extra = self._transaction_extra(mode, bookmarks, metadata, timeout)
        log.debug("[#%04X]  C: BEGIN %r", self.local_port, extra)
        self._append(b"\x11", (extra,), Response(self, **handlers))
        self._is_reset = False

    def commit(self, **handlers):
        """ Queue COMMIT, tracked with a CommitResponse. """
        log.debug("[#%04X]  C: COMMIT", self.local_port)
        self._append(b"\x12", (), CommitResponse(self, **handlers))

    def rollback(self, **handlers):
        """ Queue ROLLBACK. """
        log.debug("[#%04X]  C: ROLLBACK", self.local_port)
        self._append(b"\x13", (), Response(self, **handlers))

    def _append(self, signature, fields=(), response=None):
        """ Add a message to the outgoing queue.

        :arg signature: the signature of the message
        :arg fields: the fields of the message as a tuple
        :arg response: a response object to handle callbacks
        """
        self.packer.pack_struct(signature, fields)
        # NOTE(review): chunk() is called twice on purpose — presumably to
        # close the data chunk and then write the empty end-of-message
        # chunk; confirm against Outbox.
        self.outbox.chunk()
        self.outbox.chunk()
        self.responses.append(response)

    def reset(self):
        """ Add a RESET message to the outgoing queue, send
        it and consume all remaining messages.

        :raises BoltProtocolError: if the server answers RESET with FAILURE
        """

        def fail(metadata):
            raise BoltProtocolError("RESET failed %r" % metadata, address=self.unresolved_address)

        log.debug("[#%04X]  C: RESET", self.local_port)
        self._append(b"\x0F", response=Response(self, on_failure=fail))
        self.send_all()
        self.fetch_all()
        self._is_reset = True

    def _send_all(self):
        """ Write any buffered outbox data to the socket, then clear the outbox. """
        data = self.outbox.view()
        if data:
            self.socket.sendall(data)
            self.outbox.clear()

    def send_all(self):
        """ Send all queued messages to the server.

        :raises ServiceUnavailable: if the connection is closed or defunct
        :raises OSError: on write failure, after logging it and deactivating
                         this address in the pool (if any)
        """
        if self.closed():
            raise ServiceUnavailable("Failed to write to closed connection {!r} ({!r})".format(
                self.unresolved_address, self.server_info.address))

        if self.defunct():
            raise ServiceUnavailable("Failed to write to defunct connection {!r} ({!r})".format(
                self.unresolved_address, self.server_info.address))

        try:
            self._send_all()
        except (IOError, OSError) as error:
            log.error("Failed to write data to connection "
                      "{!r} ({!r}); ({!r})".
                      format(self.unresolved_address,
                             self.server_info.address,
                             "; ".join(map(repr, error.args))))
            if self.pool:
                self.pool.deactivate(address=self.unresolved_address)
            raise

    def fetch_message(self):
        """ Receive at least one message from the server, if available.

        :return: 2-tuple of number of detail messages and number of summary
                 messages fetched
        :raises ServiceUnavailable: if the connection is closed or defunct
        :raises BoltProtocolError: on an unexpected summary signature
        """
        if self._closed:
            raise ServiceUnavailable("Failed to read from closed connection {!r} ({!r})".format(
                self.unresolved_address, self.server_info.address))

        if self._defunct:
            raise ServiceUnavailable("Failed to read from defunct connection {!r} ({!r})".format(
                self.unresolved_address, self.server_info.address))

        if not self.responses:
            return 0, 0

        # Receive exactly one message
        try:
            details, summary_signature, summary_metadata = next(self.inbox)
        except (IOError, OSError) as error:
            log.error("Failed to read data from connection "
                      "{!r} ({!r}); ({!r})".
                      format(self.unresolved_address,
                             self.server_info.address,
                             "; ".join(map(repr, error.args))))
            if self.pool:
                self.pool.deactivate(address=self.unresolved_address)
            raise

        # Records always belong to the oldest outstanding response.
        if details:
            log.debug("[#%04X]  S: RECORD * %d", self.local_port, len(details))  # Do not log any data
            self.responses[0].on_records(details)

        if summary_signature is None:
            return len(details), 0

        response = self.responses.popleft()
        response.complete = True
        if summary_signature == b"\x70":
            log.debug("[#%04X]  S: SUCCESS %r", self.local_port, summary_metadata)
            response.on_success(summary_metadata or {})
        elif summary_signature == b"\x7E":
            log.debug("[#%04X]  S: IGNORED", self.local_port)
            response.on_ignored(summary_metadata or {})
        elif summary_signature == b"\x7F":
            log.debug("[#%04X]  S: FAILURE %r", self.local_port, summary_metadata)
            try:
                response.on_failure(summary_metadata or {})
            except (ServiceUnavailable, DatabaseUnavailable):
                if self.pool:
                    self.pool.deactivate(address=self.unresolved_address)
                raise
            except (NotALeader, ForbiddenOnReadOnlyDatabase):
                if self.pool:
                    self.pool.on_write_failure(address=self.unresolved_address)
                raise
        else:
            # Signatures are compared as bytes above, and "%02X" cannot
            # format bytes (it would raise TypeError). Render them as hex
            # explicitly; integers are still formatted as before.
            if isinstance(summary_signature, (bytes, bytearray)):
                signature_hex = summary_signature.hex().upper()
            else:
                signature_hex = "%02X" % summary_signature
            raise BoltProtocolError("Unexpected response message with signature %s" % signature_hex,
                                    address=self.unresolved_address)

        return len(details), 1

    def _set_defunct(self, error=None):
        """ Mark the connection defunct after a read failure and raise.

        Closes the connection and deactivates its address in the pool (if
        any).

        :raises BoltIncompleteCommitError: if a COMMIT is still outstanding
        :raises ServiceUnavailable: for a BoltPool (direct driver)
        :raises SessionExpired: otherwise
        """
        direct_driver = isinstance(self.pool, BoltPool)

        message = ("Failed to read from defunct connection {!r} ({!r})".format(
            self.unresolved_address, self.server_info.address))

        log.error(message)
        # We were attempting to receive data but the connection
        # has unexpectedly terminated. So, we need to close the
        # connection from the client side, and remove the address
        # from the connection pool.
        self._defunct = True
        self.close()
        if self.pool:
            self.pool.deactivate(address=self.unresolved_address)
        # Iterate through the outstanding responses, and if any correspond
        # to COMMIT requests then raise an error to signal that we are
        # unable to confirm that the COMMIT completed successfully.
        for response in self.responses:
            if isinstance(response, CommitResponse):
                raise BoltIncompleteCommitError(message, address=None)

        if direct_driver:
            raise ServiceUnavailable(message)
        else:
            raise SessionExpired(message)

    def timedout(self):
        """ True once age reaches a non-negative lifetime; a negative
        lifetime never times out.
        """
        return 0 <= self._max_connection_lifetime <= perf_counter() - self._creation_timestamp

    def fetch_all(self):
        """ Fetch all outstanding messages.

        :return: 2-tuple of number of detail messages and number of summary
                 messages fetched
        """
        detail_count = summary_count = 0
        while self.responses:
            response = self.responses[0]
            while not response.complete:
                detail_delta, summary_delta = self.fetch_message()
                detail_count += detail_delta
                summary_count += summary_delta
        return detail_count, summary_count

    def close(self):
        """ Close the connection.

        Unless the connection is defunct, a GOODBYE message is sent first on
        a best-effort basis. The socket is then closed and the connection is
        marked closed even if closing the socket raises IOError.
        """
        if not self._closed:
            if not self._defunct:
                log.debug("[#%04X]  C: GOODBYE", self.local_port)
                self._append(b"\x02", ())
                try:
                    self._send_all()
                except Exception:
                    # Best effort: the socket is closed regardless. Only
                    # ordinary errors are ignored, so KeyboardInterrupt and
                    # SystemExit still propagate.
                    pass
            log.debug("[#%04X]  C: <CLOSE>", self.local_port)
            try:
                self.socket.close()
            except IOError:
                pass
            finally:
                self._closed = True

    def closed(self):
        """ True once close() has run. """
        return self._closed

    def defunct(self):
        """ True once the connection has been marked defunct. """
        return self._defunct
コード例 #12
0
 def packb(cls, *values):
     """Pack each of *values* in order and return all of the produced bytes."""
     buffer = BytesIO()
     pack = Packer(buffer).pack
     for item in values:
         pack(item)
     return buffer.getvalue()
コード例 #13
0
 def test_map_size_overflow(self):
     """A map header for 2 ** 32 entries must be rejected with OverflowError."""
     too_many = 2 ** 32
     packer = Packer(BytesIO())
     with raises(OverflowError):
         packer.pack_map_header(too_many)
コード例 #14
0
 def test_map_stream(self):
     """Stream-pack the map {"A": 1, "B": 2} and check both the bytes and the round trip.

     Mismatches are reported with an explicit ``raise AssertionError`` so the
     checks still run under ``python -O``, which strips assert statements.
     """
     packed_value = b"\xDB\x81A\x01\x81B\x02\xDF"
     unpacked_value = {u"A": 1, u"B": 2}
     stream_out = BytesIO()
     packer = Packer(stream_out)
     packer.pack_map_stream_header()
     packer.pack(u"A")
     packer.pack(1)
     packer.pack(u"B")
     packer.pack(2)
     packer.pack_end_of_stream()
     packed = stream_out.getvalue()
     if packed != packed_value:
         raise AssertionError("Packed value is %r instead of expected %r" %
                              (packed, packed_value))
     unpacked = Unpacker(UnpackableBuffer(packed)).unpack()
     if unpacked != unpacked_value:
         raise AssertionError("Unpacked value %r is not equal to expected %r" %
                              (unpacked, unpacked_value))
コード例 #15
0
class Bolt:
    """ Server connection speaking the Bolt protocol.

    The message set used here (HELLO, GOODBYE, BEGIN, COMMIT, ROLLBACK,
    and RUN with transaction metadata/timeout) is the one introduced in
    Bolt v3.

    A :class:`.Bolt` instance should be constructed following a
    successful Bolt handshake and takes over the socket on which
    the handshake was carried out (see :meth:`.open`).

    .. note:: logs outgoing (``C:``) and incoming (``S:``) messages at
       DEBUG level, and I/O failures at ERROR level
    """

    #: The protocol version in use on this connection
    protocol_version = 0

    #: Server details for this connection
    server = None

    #: Checked-out flag; not read or written within this class
    #: (presumably managed by the owning pool — confirm)
    in_use = False

    # Set once close() has run; exposed through closed()
    _closed = False

    # Set by _set_defunct() when the inbox reports a failure
    _defunct = False

    #: The pool of which this connection is a member
    pool = None

    #: Error class used for raising connection errors
    # TODO: separate errors for connector API
    Error = ServiceUnavailable

    @classmethod
    def ping(cls, address, *, timeout=None, **config):
        """ Attempt to establish a Bolt connection, returning the
        agreed Bolt protocol version if successful.

        Only the handshake is performed; the socket is closed straight
        away and no HELLO is sent.

        :param address: server address to connect to
        :param timeout: connection timeout, passed through to ``connect``
        :param config: pool configuration options
        :return: the negotiated protocol version, or :const:`None` if
                 ``connect`` raised :class:`.ServiceUnavailable`
        """
        config = PoolConfig.consume(config)
        try:
            s, protocol_version = connect(address,
                                          timeout=timeout,
                                          config=config)
        except ServiceUnavailable:
            return None
        else:
            s.close()
            return protocol_version

    @classmethod
    def open(cls, address, *, auth=None, timeout=None, **config):
        """ Open a new Bolt connection to a given server address.

        Performs the handshake, wraps the resulting socket in a
        :class:`.Bolt` and exchanges HELLO with the server.

        :param address: server address to connect to
        :param auth: authentication details; a 2- or 3-tuple is treated
                     as basic auth, any other truthy value must expose its
                     fields through ``vars()``
        :param timeout: connection timeout, passed through to ``connect``
        :param config: pool configuration options
        :return: a :class:`.Bolt` connection on which HELLO has completed
        :raise: whatever the handshake or HELLO exchange raises; if HELLO
                fails, the connection is closed before the error
                propagates
        """
        config = PoolConfig.consume(config)
        s, config.protocol_version = connect(address,
                                             timeout=timeout,
                                             config=config)
        connection = Bolt(address, s, auth=auth, **config)
        try:
            connection.hello()
        except Exception:
            # Don't leave the socket open until garbage collection when
            # HELLO is rejected or the exchange fails.
            connection.close()
            raise
        return connection

    def __init__(self,
                 unresolved_address,
                 sock,
                 *,
                 auth=None,
                 protocol_version=None,
                 **config):
        """ Wrap an already-handshaken socket.

        :param unresolved_address: the address as originally requested;
                                   used in error messages and for pool
                                   deactivation
        :param sock: connected socket (plain or SSL)
        :param auth: see :meth:`open`
        :param protocol_version: negotiated Bolt protocol version
        :param config: pool configuration options
        :raise AuthError: if *auth* cannot be interpreted, or its
                          ``credentials`` entry is None
        """
        self.config = PoolConfig.consume(config)
        self.protocol_version = protocol_version
        self.unresolved_address = unresolved_address
        self.socket = sock
        self.server = ServerInfo(Address(sock.getpeername()), protocol_version)
        self.outbox = Outbox()
        # The inbox reads through a buffered wrapper around the socket and
        # reports failures via _set_defunct.
        self.inbox = Inbox(BufferedSocket(self.socket, 32768),
                           on_error=self._set_defunct)
        self.packer = Packer(self.outbox)
        self.unpacker = Unpacker(self.inbox)
        # One entry per queued request, consumed in FIFO order as summaries
        # arrive.
        self.responses = deque()
        self._max_connection_lifetime = self.config.max_age
        self._creation_timestamp = perf_counter()

        # Determine the user agent
        user_agent = self.config.user_agent
        if user_agent:
            self.user_agent = user_agent
        else:
            self.user_agent = get_user_agent()

        # Determine auth details
        if not auth:
            self.auth_dict = {}
        elif isinstance(auth, tuple) and 2 <= len(auth) <= 3:
            from neo4j import Auth
            self.auth_dict = vars(Auth("basic", *auth))
        else:
            try:
                self.auth_dict = vars(auth)
            except (KeyError, TypeError):
                raise AuthError("Cannot determine auth details from %r" % auth)

        # Check for missing password
        try:
            credentials = self.auth_dict["credentials"]
        except KeyError:
            pass
        else:
            if credentials is None:
                raise AuthError("Password cannot be None")

    @property
    def secure(self):
        """ True if the underlying socket is an :class:`ssl.SSLSocket`. """
        return isinstance(self.socket, SSLSocket)

    @property
    def der_encoded_server_certificate(self):
        """ The peer certificate in DER (binary) form; relies on the
        socket providing ``getpeercert``, i.e. a TLS socket.
        """
        return self.socket.getpeercert(binary_form=True)

    @property
    def local_port(self):
        """ Local port of the socket, or 0 if it cannot be determined. """
        try:
            return self.socket.getsockname()[1]
        except IOError:
            return 0

    def hello(self):
        """ Send HELLO (user agent plus auth details) and wait for every
        outstanding response. On success the reply metadata is merged
        into ``self.server.metadata``.
        """
        headers = {"user_agent": self.user_agent}
        headers.update(self.auth_dict)
        # Never log the password itself.
        logged_headers = dict(headers)
        if "credentials" in logged_headers:
            logged_headers["credentials"] = "*******"
        log.debug("[#%04X]  C: HELLO %r", self.local_port, logged_headers)
        self._append(b"\x01", (headers, ),
                     response=InitResponse(
                         self, on_success=self.server.metadata.update))
        self.send_all()
        self.fetch_all()

    def __del__(self):
        # Best-effort cleanup; errors here (for example during interpreter
        # shutdown or on a half-constructed object) are deliberately ignored.
        try:
            self.close()
        except Exception:
            pass

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    @staticmethod
    def _transaction_extra(mode=None, bookmarks=None, metadata=None,
                           timeout=None):
        """ Build the "extra" dictionary shared by RUN and BEGIN.

        Falsy arguments are left out. *timeout* is given in seconds and
        stored as whole milliseconds.

        :raise TypeError: if ``list(bookmarks)``, ``dict(metadata)`` or
                          ``int(1000 * timeout)`` raises TypeError
        """
        extra = {}
        if mode:
            extra["mode"] = mode
        if bookmarks:
            try:
                extra["bookmarks"] = list(bookmarks)
            except TypeError:
                raise TypeError(
                    "Bookmarks must be provided within an iterable")
        if metadata:
            try:
                extra["tx_metadata"] = dict(metadata)
            except TypeError:
                raise TypeError("Metadata must be coercible to a dict")
        if timeout:
            # NOTE(review): a str timeout is not rejected here ("5" * 1000
            # is a valid int literal) — consider validating the type.
            try:
                extra["tx_timeout"] = int(1000 * timeout)
            except TypeError:
                raise TypeError(
                    "Timeout must be specified as a number of seconds")
        return extra

    def run(self,
            statement,
            parameters=None,
            mode=None,
            bookmarks=None,
            metadata=None,
            timeout=None,
            **handlers):
        """ Queue a RUN message; nothing is sent until :meth:`send_all`.

        A statement that is literally ``COMMIT`` (any case) gets a
        :class:`CommitResponse`, so an unconfirmed commit can be detected
        if the connection fails.
        """
        if not parameters:
            parameters = {}
        extra = self._transaction_extra(mode, bookmarks, metadata, timeout)
        fields = (statement, parameters, extra)
        log.debug("[#%04X]  C: RUN %s", self.local_port,
                  " ".join(map(repr, fields)))
        if statement.upper() == u"COMMIT":
            self._append(b"\x10", fields, CommitResponse(self, **handlers))
        else:
            self._append(b"\x10", fields, Response(self, **handlers))

    def discard_all(self, **handlers):
        """ Queue a DISCARD_ALL message. """
        log.debug("[#%04X]  C: DISCARD_ALL", self.local_port)
        self._append(b"\x2F", (), Response(self, **handlers))

    def pull_all(self, **handlers):
        """ Queue a PULL_ALL message. """
        log.debug("[#%04X]  C: PULL_ALL", self.local_port)
        self._append(b"\x3F", (), Response(self, **handlers))

    def begin(self,
              mode=None,
              bookmarks=None,
              metadata=None,
              timeout=None,
              **handlers):
        """ Queue a BEGIN message opening an explicit transaction. """
        extra = self._transaction_extra(mode, bookmarks, metadata, timeout)
        log.debug("[#%04X]  C: BEGIN %r", self.local_port, extra)
        self._append(b"\x11", (extra, ), Response(self, **handlers))

    def commit(self, **handlers):
        """ Queue a COMMIT message (tracked with a CommitResponse). """
        log.debug("[#%04X]  C: COMMIT", self.local_port)
        self._append(b"\x12", (), CommitResponse(self, **handlers))

    def rollback(self, **handlers):
        """ Queue a ROLLBACK message. """
        log.debug("[#%04X]  C: ROLLBACK", self.local_port)
        self._append(b"\x13", (), Response(self, **handlers))

    def _append(self, signature, fields=(), response=None):
        """ Add a message to the outgoing queue.

        :arg signature: the signature of the message
        :arg fields: the fields of the message as a tuple
        :arg response: a response object to handle callbacks
        """
        self.packer.pack_struct(signature, fields)
        # NOTE(review): the repeated chunk() looks deliberate — the second
        # call appears to emit the empty chunk that terminates a message.
        # Confirm against Outbox.chunk before "deduplicating".
        self.outbox.chunk()
        self.outbox.chunk()
        self.responses.append(response)

    def reset(self):
        """ Add a RESET message to the outgoing queue, send
        it and consume all remaining messages.

        :raise ProtocolError: if the server answers RESET with FAILURE
        """
        def fail(metadata):
            raise ProtocolError("RESET failed %r" % metadata)

        log.debug("[#%04X]  C: RESET", self.local_port)
        self._append(b"\x0F", response=Response(self, on_failure=fail))
        self.send_all()
        self.fetch_all()

    def _send_all(self):
        """ Write the outbox to the socket and clear it; no state checks. """
        data = self.outbox.view()
        if data:
            self.socket.sendall(data)
            self.outbox.clear()

    def send_all(self):
        """ Send all queued messages to the server.

        :raise self.Error: if the connection is closed or defunct
        :raise IOError, OSError: on write failure, after logging it and
                                 deactivating the address in the pool
        """
        if self.closed():
            raise self.Error("Failed to write to closed connection "
                             "{!r} ({!r})".format(self.unresolved_address,
                                                  self.server.address))
        if self.defunct():
            raise self.Error("Failed to write to defunct connection "
                             "{!r} ({!r})".format(self.unresolved_address,
                                                  self.server.address))
        try:
            self._send_all()
        except (IOError, OSError) as error:
            log.error("Failed to write data to connection "
                      "{!r} ({!r}); ({!r})".format(
                          self.unresolved_address, self.server.address,
                          "; ".join(map(repr, error.args))))
            if self.pool:
                self.pool.deactivate(self.unresolved_address)
            raise

    def fetch_message(self):
        """ Receive the next batch from the server, if a response is
        outstanding: zero or more detail records, then at most one
        summary message (SUCCESS, IGNORED or FAILURE).

        :return: 2-tuple of number of detail messages and number of summary
                 messages fetched
        :raise self.Error: if the connection is closed or defunct
        :raise ProtocolError: on a summary with an unknown signature
        """
        if self._closed:
            raise self.Error("Failed to read from closed connection "
                             "{!r} ({!r})".format(self.unresolved_address,
                                                  self.server.address))
        if self._defunct:
            raise self.Error("Failed to read from defunct connection "
                             "{!r} ({!r})".format(self.unresolved_address,
                                                  self.server.address))
        if not self.responses:
            return 0, 0

        # Receive exactly one message
        try:
            details, summary_signature, summary_metadata = next(self.inbox)
        except (IOError, OSError) as error:
            log.error("Failed to read data from connection "
                      "{!r} ({!r}); ({!r})".format(
                          self.unresolved_address, self.server.address,
                          "; ".join(map(repr, error.args))))
            if self.pool:
                self.pool.deactivate(self.unresolved_address)
            raise

        if details:
            log.debug("[#%04X]  S: RECORD * %d", self.local_port,
                      len(details))  # TODO
            self.responses[0].on_records(details)

        if summary_signature is None:
            return len(details), 0

        response = self.responses.popleft()
        response.complete = True
        if summary_signature == b"\x70":
            log.debug("[#%04X]  S: SUCCESS %r", self.local_port,
                      summary_metadata)
            response.on_success(summary_metadata or {})
        elif summary_signature == b"\x7E":
            log.debug("[#%04X]  S: IGNORED", self.local_port)
            response.on_ignored(summary_metadata or {})
        elif summary_signature == b"\x7F":
            log.debug("[#%04X]  S: FAILURE %r", self.local_port,
                      summary_metadata)
            try:
                response.on_failure(summary_metadata or {})
            except (ConnectionExpired, ServiceUnavailable,
                    DatabaseUnavailableError):
                if self.pool:
                    self.pool.deactivate(self.unresolved_address)
                raise
            except (NotALeaderError, ForbiddenOnReadOnlyDatabaseError):
                if self.pool:
                    self.pool.on_write_failure(self.unresolved_address)
                raise
        else:
            # The signature is a one-byte bytes value (compared with
            # b"\x70" etc. above); "%02X" needs an int, so convert it
            # rather than failing with TypeError here.
            raise ProtocolError("Unexpected response message with "
                                "signature %02X" % ord(summary_signature))

        return len(details), 1

    def _set_defunct(self, error=None):
        """ Mark the connection defunct after a failure, then always raise.

        The connection is closed (no GOODBYE, since it is defunct) and its
        address is deactivated in the pool. *error* is accepted but not
        used.

        :raise IncompleteCommitError: if any outstanding response belongs
                                      to a COMMIT
        :raise self.Error: otherwise
        """
        message = ("Failed to read from defunct connection "
                   "{!r} ({!r})".format(self.unresolved_address,
                                        self.server.address))
        log.error(message)
        # We were attempting to receive data but the connection
        # has unexpectedly terminated. So, we need to close the
        # connection from the client side, and remove the address
        # from the connection pool.
        self._defunct = True
        self.close()
        if self.pool:
            self.pool.deactivate(self.unresolved_address)
        # Iterate through the outstanding responses, and if any correspond
        # to COMMIT requests then raise an error to signal that we are
        # unable to confirm that the COMMIT completed successfully.
        for response in self.responses:
            if isinstance(response, CommitResponse):
                raise IncompleteCommitError(message)
        raise self.Error(message)

    def timedout(self):
        """ True once the connection's age reaches its maximum lifetime;
        a negative lifetime disables the check.
        """
        age = perf_counter() - self._creation_timestamp
        return 0 <= self._max_connection_lifetime <= age

    def fetch_all(self):
        """ Fetch all outstanding messages.

        :return: 2-tuple of number of detail messages and number of summary
                 messages fetched
        """
        detail_count = summary_count = 0
        while self.responses:
            response = self.responses[0]
            while not response.complete:
                detail_delta, summary_delta = self.fetch_message()
                detail_count += detail_delta
                summary_count += summary_delta
        return detail_count, summary_count

    def close(self):
        """ Close the connection.

        Unless the connection is defunct, a GOODBYE message is queued and
        a best-effort attempt is made to send it. The socket is then
        closed and the connection marked closed. Calling this again after
        a successful close does nothing.
        """
        if self._closed:
            return
        try:
            if not self._defunct:
                log.debug("[#%04X]  C: GOODBYE", self.local_port)
                self._append(b"\x02", ())
                try:
                    self._send_all()
                except Exception:
                    # Best effort only: the peer may already be gone.
                    pass
        finally:
            # Always release the socket, even if queuing GOODBYE raised.
            log.debug("[#%04X]  C: <CLOSE>", self.local_port)
            try:
                self.socket.close()
            except IOError:
                pass
            finally:
                self._closed = True

    def closed(self):
        """ True once :meth:`close` has run. """
        return self._closed

    def defunct(self):
        """ True once the connection has been marked defunct. """
        return self._defunct
コード例 #16
0
 def encode_message(cls, tag, *fields):
     """ Encode a structure: a marker byte of 0xB0 plus the field count,
     the *tag* byte, then each of *fields* packed in order.

     :raise OverflowError: if more than 15 fields are given, because the
         count must fit in the low nibble of the 0xB0 marker
     """
     if len(fields) > 0x0F:
         # 0xB0 + 16 == 0xC0 no longer encodes a count in the low nibble,
         # and 80+ fields would not even fit in a byte.
         raise OverflowError("Structure size out of range")
     b = BytesIO()
     packer = Packer(b)
     for field in fields:
         packer.pack(field)
     return bytearray([0xB0 + len(fields), tag]) + b.getvalue()