def __init__(self, address, sock, protocol_version, error_handler, **config):
    """ Set up per-connection state on a socket that has completed the Bolt
    handshake.

    :param address: address this connection was opened against
    :param sock: connected socket
    :param protocol_version: negotiated Bolt protocol version
    :param error_handler: object used to handle connection errors
    :param config: optional settings; reads ``max_connection_lifetime``,
        ``user_agent``, ``auth`` and ``der_encoded_server_certificate``
    :raise TypeError: if ``auth`` cannot be turned into an auth dictionary
    """
    self.address = address
    self.socket = sock
    self.protocol_version = protocol_version
    self.error_handler = error_handler
    self.server = ServerInfo(SocketAddress.from_socket(sock))
    self.input_buffer = ChunkedInputBuffer()
    self.output_buffer = ChunkedOutputBuffer()
    self.packer = Packer(self.output_buffer)
    self.unpacker = Unpacker()
    self.responses = deque()
    self._max_connection_lifetime = config.get(
        "max_connection_lifetime", default_config["max_connection_lifetime"])
    self._creation_timestamp = perf_counter()

    # Determine the user agent and ensure it is a Unicode value
    user_agent = config.get("user_agent", default_config["user_agent"])
    if isinstance(user_agent, bytes):
        user_agent = user_agent.decode("UTF-8")
    self.user_agent = user_agent

    # Determine auth details
    auth = config.get("auth")
    if not auth:
        self.auth_dict = {}
    elif isinstance(auth, tuple) and 2 <= len(auth) <= 3:
        from neo4j.v1 import basic_auth
        self.auth_dict = vars(basic_auth(*auth))
    else:
        try:
            self.auth_dict = vars(auth)
        except (KeyError, TypeError):
            # Wrap in a 1-tuple: a bare tuple operand to % would be unpacked
            # as the format arguments and break the message.
            raise TypeError("Cannot determine auth details from %r" % (auth,))

    # Pick up the server certificate, if any
    self.der_encoded_server_certificate = config.get(
        "der_encoded_server_certificate")
def test_dehydration_2d(self):
    """ A 2D Cartesian point packs as a 3-field 'X' structure: SRID 7203
    followed by each coordinate as a float64. """
    coords = (.1, 0)
    point = CartesianPoint(coords)
    dehydrator = DataDehydrator()
    out = io.BytesIO()
    packer = Packer(out)
    packer.pack(dehydrator.dehydrate((point, ))[0])
    header = b"\xB3X" + b"\xC9" + struct.pack(">h", 7203)
    body = b"".join([b"\xC1" + struct.pack(">d", c) for c in coords])
    self.assertEqual(out.getvalue(), header + body)
def test_dehydration_3d(self):
    """ A 3D WGS-84 point packs as a 4-field 'Y' structure: SRID 4979
    followed by each coordinate as a float64. """
    coords = (1, -2, 3.1)
    point = WGS84Point(coords)
    dehydrator = DataDehydrator()
    out = io.BytesIO()
    packer = Packer(out)
    packer.pack(dehydrator.dehydrate((point, ))[0])
    header = b"\xB4Y" + b"\xC9" + struct.pack(">h", 4979)
    body = b"".join([b"\xC1" + struct.pack(">d", c) for c in coords])
    self.assertEqual(out.getvalue(), header + body)
def test_dehydration(self):
    """ A custom 2D point type registered with SRID 1234 packs as a 3-field
    'X' structure carrying that SRID and float64 coordinates. """
    PointKind = point_type("MyPoint", ["x", "y"], {2: 1234})
    coords = (.1, 0)
    point = PointKind(coords)
    dehydrator = DataDehydrator()
    out = io.BytesIO()
    packer = Packer(out)
    packer.pack(dehydrator.dehydrate((point, ))[0])
    header = b"\xB3X" + b"\xC9" + struct.pack(">h", 1234)
    body = b"".join([b"\xC1" + struct.pack(">d", c) for c in coords])
    self.assertEqual(out.getvalue(), header + body)
def __init__(self, unresolved_address, sock, max_connection_lifetime, *, auth=None, user_agent=None, routing_context=None):
    """ Set up per-connection state on a socket that has completed the Bolt
    handshake.

    :param unresolved_address: address as originally requested
    :param sock: connected socket; its peer name populates ``server_info``
    :param max_connection_lifetime: maximum age (seconds, per
        ``perf_counter``) before the connection counts as timed out
    :param auth: falsy for no auth, a 2- or 3-tuple passed to
        ``Auth("basic", ...)``, or any object whose ``vars()`` is the auth dict
    :param user_agent: user agent string; ``get_user_agent()`` when falsy
    :param routing_context: stored as-is on ``self.routing_context``
    :raise AuthError: if auth details cannot be determined, or the
        credentials entry is present but ``None``
    """
    self.unresolved_address = unresolved_address
    self.socket = sock
    self.server_info = ServerInfo(Address(sock.getpeername()), self.PROTOCOL_VERSION)
    self.outbox = Outbox()
    self.inbox = Inbox(self.socket, on_error=self._set_defunct)
    self.packer = Packer(self.outbox)
    self.unpacker = Unpacker(self.inbox)
    self.responses = deque()
    self._max_connection_lifetime = max_connection_lifetime
    self._creation_timestamp = perf_counter()
    self.supports_multiple_results = False
    self.supports_multiple_databases = False
    self._is_reset = True
    self.routing_context = routing_context

    # Determine the user agent
    if user_agent:
        self.user_agent = user_agent
    else:
        self.user_agent = get_user_agent()

    # Determine auth details
    if not auth:
        self.auth_dict = {}
    elif isinstance(auth, tuple) and 2 <= len(auth) <= 3:
        from neo4j import Auth
        self.auth_dict = vars(Auth("basic", *auth))
    else:
        try:
            self.auth_dict = vars(auth)
        except (KeyError, TypeError):
            # Wrap in a 1-tuple: a bare tuple operand to % would be unpacked
            # as the format arguments and raise TypeError instead of AuthError.
            raise AuthError("Cannot determine auth details from %r" % (auth,))

    # Check for missing password
    try:
        credentials = self.auth_dict["credentials"]
    except KeyError:
        pass
    else:
        if credentials is None:
            raise AuthError("Password cannot be None")
def assert_packable(cls, value, packed_value):
    """ Check that *value* packs to exactly *packed_value* and that unpacking
    those bytes yields something equal to *value*.

    :raise AssertionError: if the packed bytes or the round-tripped value
        differ from what is expected
    """
    stream_out = BytesIO()
    packer = Packer(stream_out)
    packer.pack(value)
    packed = stream_out.getvalue()
    # Explicit comparisons rather than ``assert`` so the checks still run
    # when Python is started with -O.
    if packed != packed_value:
        raise AssertionError("Packed value %r is %r instead of expected %r" %
                             (value, packed, packed_value))
    unpacked = Unpacker(UnpackableBuffer(packed)).unpack()
    if unpacked != value:
        raise AssertionError("Unpacked value %r is not equal to original %r" %
                             (unpacked, value))
def __init__(self, unresolved_address, sock, *, auth=None, protocol_version=None, **config):
    """ Set up per-connection state on a socket that has completed the Bolt
    handshake.

    :param unresolved_address: address as originally requested
    :param sock: connected socket; its peer name populates ``server``
    :param auth: falsy for no auth, a 2- or 3-tuple passed to
        ``Auth("basic", ...)``, or any object whose ``vars()`` is the auth dict
    :param protocol_version: negotiated Bolt protocol version
    :param config: consumed into a ``PoolConfig``; ``max_age`` and
        ``user_agent`` are read from it
    :raise AuthError: if auth details cannot be determined, or the
        credentials entry is present but ``None``
    """
    self.config = PoolConfig.consume(config)
    self.protocol_version = protocol_version
    self.unresolved_address = unresolved_address
    self.socket = sock
    self.server = ServerInfo(Address(sock.getpeername()), protocol_version)
    self.outbox = Outbox()
    self.inbox = Inbox(BufferedSocket(self.socket, 32768), on_error=self._set_defunct)
    self.packer = Packer(self.outbox)
    self.unpacker = Unpacker(self.inbox)
    self.responses = deque()
    self._max_connection_lifetime = self.config.max_age
    self._creation_timestamp = perf_counter()

    # Determine the user agent
    user_agent = self.config.user_agent
    if user_agent:
        self.user_agent = user_agent
    else:
        self.user_agent = get_user_agent()

    # Determine auth details
    if not auth:
        self.auth_dict = {}
    elif isinstance(auth, tuple) and 2 <= len(auth) <= 3:
        from neo4j import Auth
        self.auth_dict = vars(Auth("basic", *auth))
    else:
        try:
            self.auth_dict = vars(auth)
        except (KeyError, TypeError):
            # Wrap in a 1-tuple: a bare tuple operand to % would be unpacked
            # as the format arguments and raise TypeError instead of AuthError.
            raise AuthError("Cannot determine auth details from %r" % (auth,))

    # Check for missing password
    try:
        credentials = self.auth_dict["credentials"]
    except KeyError:
        pass
    else:
        if credentials is None:
            raise AuthError("Password cannot be None")
def test_list_stream(self):
    """ A streamed list (header, items, end marker) packs to the expected
    bytes and unpacks back to a plain list. """
    packed_value = b"\xD7\x01\x02\x03\xDF"
    unpacked_value = [1, 2, 3]
    stream_out = BytesIO()
    packer = Packer(stream_out)
    packer.pack_list_stream_header()
    packer.pack(1)
    packer.pack(2)
    packer.pack(3)
    packer.pack_end_of_stream()
    packed = stream_out.getvalue()
    # Explicit comparisons rather than ``assert`` so the checks still run
    # when Python is started with -O.
    if packed != packed_value:
        raise AssertionError("Packed value is %r instead of expected %r" %
                             (packed, packed_value))
    unpacked = Unpacker(UnpackableBuffer(packed)).unpack()
    if unpacked != unpacked_value:
        raise AssertionError("Unpacked value %r is not equal to expected %r" %
                             (unpacked, unpacked_value))
def __init__(self, sock, **config):
    """ Set up connection state on a handshaken socket, then send INIT and
    wait for the reply.

    :param sock: connected socket
    :param config: optional settings; reads ``user_agent`` (default
        ``DEFAULT_USER_AGENT``), ``auth`` and ``der_encoded_server_certificate``
    :raise TypeError: if ``auth`` cannot be turned into an auth dictionary
    """
    self.socket = sock
    self.server = ServerInfo(SocketAddress.from_socket(sock))
    self.input_buffer = ChunkedInputBuffer()
    self.output_buffer = ChunkedOutputBuffer()
    self.packer = Packer(self.output_buffer)
    self.unpacker = Unpacker()
    self.responses = deque()

    # Determine the user agent and ensure it is a Unicode value
    user_agent = config.get("user_agent", DEFAULT_USER_AGENT)
    if isinstance(user_agent, bytes):
        user_agent = user_agent.decode("UTF-8")
    self.user_agent = user_agent

    # Determine auth details
    auth = config.get("auth")
    if not auth:
        self.auth_dict = {}
    elif isinstance(auth, tuple) and 2 <= len(auth) <= 3:
        from neo4j.v1 import basic_auth
        self.auth_dict = vars(basic_auth(*auth))
    else:
        try:
            self.auth_dict = vars(auth)
        except (KeyError, TypeError):
            # Wrap in a 1-tuple: a bare tuple operand to % would be unpacked
            # as the format arguments and break the message.
            raise TypeError("Cannot determine auth details from %r" % (auth,))

    # Pick up the server certificate, if any
    self.der_encoded_server_certificate = config.get(
        "der_encoded_server_certificate")

    # Send INIT immediately and block until the server replies
    response = InitResponse(self)
    self.append(INIT, (self.user_agent, self.auth_dict), response=response)
    self.sync()

    self._supports_statement_reuse = self.server.supports_statement_reuse()
    self.packer.supports_bytes = self.server.supports_bytes()
class Connection(object):
    """ Server connection for Bolt protocol v1.

    A :class:`.Connection` should be constructed following a
    successful Bolt handshake and takes the socket over which
    the handshake was carried out.

    .. note:: protocol traffic is logged via ``log_debug``
    """

    #: The protocol version in use on this connection
    protocol_version = 0

    #: Server details for this connection
    server = None

    in_use = False

    # Set to True once close() has run
    _closed = False

    # Set to True by _receive() when the socket read fails
    _defunct = False

    #: The pool of which this connection is a member
    pool = None

    #: Error class used for raising connection errors
    Error = ServiceUnavailable

    # Updated from the server's capabilities in init()
    _supports_statement_reuse = False

    # Text of the last RUN statement sent; reset on IGNORED/FAILURE
    _last_run_statement = None

    def __init__(self, address, sock, protocol_version, error_handler, **config):
        """ Set up per-connection state on a handshaken socket.

        :param address: address this connection was opened against; passed
            to ``error_handler.handle`` on known errors
        :param sock: connected socket
        :param protocol_version: negotiated Bolt protocol version
        :param error_handler: object exposing ``known_errors`` and ``handle``
        :param config: optional settings; reads ``max_connection_lifetime``,
            ``user_agent``, ``auth`` and ``der_encoded_server_certificate``
        :raise TypeError: if ``auth`` cannot be turned into an auth dictionary
        """
        self.address = address
        self.socket = sock
        self.protocol_version = protocol_version
        self.error_handler = error_handler
        self.server = ServerInfo(SocketAddress.from_socket(sock))
        self.input_buffer = ChunkedInputBuffer()
        self.output_buffer = ChunkedOutputBuffer()
        self.packer = Packer(self.output_buffer)
        self.unpacker = Unpacker()
        self.responses = deque()
        self._max_connection_lifetime = config.get(
            "max_connection_lifetime", default_config["max_connection_lifetime"])
        self._creation_timestamp = perf_counter()

        # Determine the user agent and ensure it is a Unicode value
        user_agent = config.get("user_agent", default_config["user_agent"])
        if isinstance(user_agent, bytes):
            user_agent = user_agent.decode("UTF-8")
        self.user_agent = user_agent

        # Determine auth details
        auth = config.get("auth")
        if not auth:
            self.auth_dict = {}
        elif isinstance(auth, tuple) and 2 <= len(auth) <= 3:
            from neo4j.v1 import basic_auth
            self.auth_dict = vars(basic_auth(*auth))
        else:
            try:
                self.auth_dict = vars(auth)
            except (KeyError, TypeError):
                # Wrap in a 1-tuple: a bare tuple operand to % would be
                # unpacked as the format arguments and break the message.
                raise TypeError("Cannot determine auth details from %r" % (auth,))

        # Pick up the server certificate, if any
        self.der_encoded_server_certificate = config.get(
            "der_encoded_server_certificate")

    def init(self):
        """ Send INIT with the user agent and auth details, wait for the
        reply, then read capability flags from ``self.server``.
        """
        response = InitResponse(self)
        self.append(INIT, (self.user_agent, self.auth_dict), response=response)
        self.sync()

        self._supports_statement_reuse = self.server.supports_statement_reuse()
        self.packer.supports_bytes = self.server.supports_bytes()

    def __del__(self):
        try:
            self.close()
        except (AttributeError, TypeError):
            pass

    def append(self, signature, fields=(), response=None):
        """ Add a message to the outgoing queue.

        :arg signature: the signature of the message
        :arg fields: the fields of the message as a tuple
        :arg response: a response object to handle callbacks
        :raise ValueError: if the signature is not a known message type
        """
        if signature == RUN:
            if self._supports_statement_reuse:
                statement = fields[0]
                if statement.upper() not in ("BEGIN", "COMMIT", "ROLLBACK"):
                    if statement == self._last_run_statement:
                        # Repeat of the previous statement: send "" in its
                        # place (server assumed to reuse the last one)
                        fields = ("", ) + fields[1:]
                    else:
                        self._last_run_statement = statement
            log_debug("C: RUN %r", fields)
        elif signature == PULL_ALL:
            log_debug("C: PULL_ALL %r", fields)
        elif signature == DISCARD_ALL:
            log_debug("C: DISCARD_ALL %r", fields)
        elif signature == RESET:
            log_debug("C: RESET %r", fields)
        elif signature == ACK_FAILURE:
            log_debug("C: ACK_FAILURE %r", fields)
        elif signature == INIT:
            # Only the user agent is logged; auth details are masked
            log_debug("C: INIT (%r, {...})", fields[0])
        else:
            raise ValueError("Unknown message signature")
        self.packer.pack_struct(signature, fields)
        # NOTE(review): the second chunk() presumably writes the empty
        # end-of-message chunk — confirm against ChunkedOutputBuffer
        self.output_buffer.chunk()
        self.output_buffer.chunk()
        self.responses.append(response)

    def acknowledge_failure(self):
        """ Add an ACK_FAILURE message to the outgoing queue, send
        it and consume all remaining messages.
        """
        self.append(ACK_FAILURE, response=AckFailureResponse(self))
        self.sync()

    def reset(self):
        """ Add a RESET message to the outgoing queue, send
        it and consume all remaining messages.
        """
        self.append(RESET, response=ResetResponse(self))
        self.sync()

    def send(self):
        """ Send queued messages, reporting known errors to the error
        handler before re-raising them. """
        try:
            self._send()
        except self.error_handler.known_errors as error:
            self.error_handler.handle(error, self.address)
            raise

    def _send(self):
        """ Send all queued messages to the server.

        :raise Error: if there is data to send but the connection is
            closed or defunct
        """
        data = self.output_buffer.view()
        if not data:
            return
        if self.closed():
            raise self.Error(
                "Failed to write to closed connection {!r}".format(
                    self.server.address))
        if self.defunct():
            raise self.Error(
                "Failed to write to defunct connection {!r}".format(
                    self.server.address))
        self.socket.sendall(data)
        self.output_buffer.clear()

    def fetch(self):
        """ Fetch one message, reporting known errors to the error handler
        before re-raising them. """
        try:
            return self._fetch()
        except self.error_handler.known_errors as error:
            self.error_handler.handle(error, self.address)
            raise

    def _fetch(self):
        """ Receive at least one message from the server, if available.

        :return: 2-tuple of number of detail messages and number of summary
                 messages fetched
        :raise Error: if the connection is closed or defunct
        :raise ProtocolError: on an unrecognised summary signature
        """
        if self.closed():
            raise self.Error(
                "Failed to read from closed connection {!r}".format(
                    self.server.address))
        if self.defunct():
            raise self.Error(
                "Failed to read from defunct connection {!r}".format(
                    self.server.address))
        if not self.responses:
            return 0, 0

        self._receive()

        details, summary_signature, summary_metadata = self._unpack()

        if details:
            log_debug("S: RECORD * %d", len(details))  # TODO
            self.responses[0].on_records(details)

        if summary_signature is None:
            return len(details), 0

        response = self.responses.popleft()
        response.complete = True
        if summary_signature == SUCCESS:
            log_debug("S: SUCCESS (%r)", summary_metadata)
            response.on_success(summary_metadata or {})
        elif summary_signature == IGNORED:
            self._last_run_statement = None
            log_debug("S: IGNORED (%r)", summary_metadata)
            response.on_ignored(summary_metadata or {})
        elif summary_signature == FAILURE:
            self._last_run_statement = None
            log_debug("S: FAILURE (%r)", summary_metadata)
            response.on_failure(summary_metadata or {})
        else:
            self._last_run_statement = None
            raise ProtocolError(
                "Unexpected response message with signature %02X" %
                summary_signature)

        return len(details), 1

    def _receive(self):
        """ Read one message from the socket into the input buffer; on
        failure mark the connection defunct, close it and raise Error. """
        try:
            received = self.input_buffer.receive_message(self.socket, 8192)
        except SocketError:
            received = False
        if not received:
            self._defunct = True
            self.close()
            raise self.Error(
                "Failed to read from defunct connection {!r}".format(
                    self.server.address))

    def _unpack(self):
        """ Decode buffered frames: collect RECORD payloads until a
        non-RECORD (summary) message or until no further message frame.

        :return: (details, summary_signature, summary_metadata); the last
                 two are None when no summary was read
        :raise ProtocolError: if a structure has more than one field
        """
        unpacker = self.unpacker
        input_buffer = self.input_buffer

        details = []
        summary_signature = None
        summary_metadata = None
        more = True
        while more:
            unpacker.attach(input_buffer.frame())
            size, signature = unpacker.unpack_structure_header()
            if size > 1:
                raise ProtocolError("Expected one field")
            if signature == RECORD:
                data = unpacker.unpack_list()
                details.append(data)
                more = input_buffer.frame_message()
            else:
                summary_signature = signature
                summary_metadata = unpacker.unpack_map()
                more = False
        return details, summary_signature, summary_metadata

    def timedout(self):
        """ True when a non-negative max lifetime is set and the connection
        is at least that old; a negative lifetime disables the check. """
        return 0 <= self._max_connection_lifetime <= perf_counter(
        ) - self._creation_timestamp

    def sync(self):
        """ Send and fetch all outstanding messages.

        :return: 2-tuple of number of detail messages and number of summary
                 messages fetched
        """
        self.send()
        detail_count = summary_count = 0
        while self.responses:
            response = self.responses[0]
            while not response.complete:
                detail_delta, summary_delta = self.fetch()
                detail_count += detail_delta
                summary_count += summary_delta
        return detail_count, summary_count

    def close(self):
        """ Close the connection.
        """
        if not self.closed():
            log_debug("~~ [CLOSE]")
            try:
                self.socket.close()
            finally:
                # Mark closed even if socket.close() raised, so later calls
                # (e.g. from __del__) do not try again.
                self._closed = True

    def closed(self):
        return self._closed

    def defunct(self):
        return self._defunct
class Bolt3(Bolt):
    """ Connection speaking Bolt protocol version 3.0 (``PROTOCOL_VERSION``).
    """

    PROTOCOL_VERSION = Version(3, 0)

    # In-use flag for this connection (presumably managed by the pool — confirm)
    in_use = False

    # Set to True once close() has run
    _closed = False

    # Set to True by _set_defunct() after an unexpected read failure
    _defunct = False

    #: The pool of which this connection is a member
    pool = None

    def __init__(self, unresolved_address, sock, max_connection_lifetime, *, auth=None, user_agent=None):
        """ Set up per-connection state on a handshaken socket.

        :param unresolved_address: address as originally requested
        :param sock: connected socket; its peer name populates ``server_info``
        :param max_connection_lifetime: maximum age before ``timedout()``
        :param auth: falsy for no auth, a 2- or 3-tuple passed to
            ``Auth("basic", ...)``, or any object whose ``vars()`` is the
            auth dict
        :param user_agent: user agent string; ``get_user_agent()`` when falsy
        :raise AuthError: if auth details cannot be determined, or the
            credentials entry is present but ``None``
        """
        self.unresolved_address = unresolved_address
        self.socket = sock
        self.server_info = ServerInfo(Address(sock.getpeername()), Bolt3.PROTOCOL_VERSION)
        self.outbox = Outbox()
        self.inbox = Inbox(self.socket, on_error=self._set_defunct)
        self.packer = Packer(self.outbox)
        self.unpacker = Unpacker(self.inbox)
        self.responses = deque()
        self._max_connection_lifetime = max_connection_lifetime
        self._creation_timestamp = perf_counter()
        self.supports_multiple_results = False
        self.supports_multiple_databases = False
        self._is_reset = True

        # Determine the user agent
        if user_agent:
            self.user_agent = user_agent
        else:
            self.user_agent = get_user_agent()

        # Determine auth details
        if not auth:
            self.auth_dict = {}
        elif isinstance(auth, tuple) and 2 <= len(auth) <= 3:
            from neo4j import Auth
            self.auth_dict = vars(Auth("basic", *auth))
        else:
            try:
                self.auth_dict = vars(auth)
            except (KeyError, TypeError):
                # Wrap in a 1-tuple: a bare tuple operand to % would be
                # unpacked and raise TypeError instead of AuthError.
                raise AuthError("Cannot determine auth details from %r" % (auth,))

        # Check for missing password
        try:
            credentials = self.auth_dict["credentials"]
        except KeyError:
            pass
        else:
            if credentials is None:
                raise AuthError("Password cannot be None")

    @property
    def encrypted(self):
        """ True if the socket is an ``SSLSocket``. """
        return isinstance(self.socket, SSLSocket)

    @property
    def der_encoded_server_certificate(self):
        """ The peer certificate in binary (DER) form. """
        return self.socket.getpeercert(binary_form=True)

    @property
    def local_port(self):
        """ Local port of the socket, or 0 if it cannot be read. """
        try:
            return self.socket.getsockname()[1]
        except IOError:
            return 0

    def hello(self):
        """ Send HELLO with user agent and auth details (credentials masked
        in the log) and wait for the reply; on success the reply metadata
        is merged into ``server_info.metadata``. """
        headers = {"user_agent": self.user_agent}
        headers.update(self.auth_dict)
        logged_headers = dict(headers)
        if "credentials" in logged_headers:
            logged_headers["credentials"] = "*******"
        log.debug("[#%04X] C: HELLO %r", self.local_port, logged_headers)
        self._append(b"\x01", (headers,),
                     response=InitResponse(self, on_success=self.server_info.metadata.update))
        self.send_all()
        self.fetch_all()

    def run(self, query, parameters=None, mode=None, bookmarks=None, metadata=None, timeout=None, db=None, **handlers):
        """ Queue a RUN message (not sent until ``send_all``).

        :raise ConfigurationError: if ``db`` is given (unsupported in Bolt 3)
        :raise TypeError: if bookmarks, metadata or timeout have the wrong type
        """
        if db is not None:
            raise ConfigurationError("Database name parameter for selecting database is not supported in Bolt Protocol {!r}. Database name {!r}.".format(Bolt3.PROTOCOL_VERSION, db))
        if not parameters:
            parameters = {}
        extra = {}
        if mode in (READ_ACCESS, "r"):
            extra["mode"] = "r"  # It will default to mode "w" if nothing is specified
        if bookmarks:
            try:
                extra["bookmarks"] = list(bookmarks)
            except TypeError:
                raise TypeError("Bookmarks must be provided within an iterable")
        if metadata:
            try:
                extra["tx_metadata"] = dict(metadata)
            except TypeError:
                raise TypeError("Metadata must be coercible to a dict")
        if timeout:
            try:
                # Seconds -> milliseconds
                extra["tx_timeout"] = int(1000 * timeout)
            except TypeError:
                raise TypeError("Timeout must be specified as a number of seconds")
        fields = (query, parameters, extra)
        log.debug("[#%04X] C: RUN %s", self.local_port, " ".join(map(repr, fields)))
        if query.upper() == u"COMMIT":
            self._append(b"\x10", fields, CommitResponse(self, **handlers))
        else:
            self._append(b"\x10", fields, Response(self, **handlers))
        self._is_reset = False

    def discard(self, n=-1, qid=-1, **handlers):
        """ Queue DISCARD_ALL. """
        # Just ignore n and qid, it is not supported in the Bolt 3 Protocol.
        log.debug("[#%04X] C: DISCARD_ALL", self.local_port)
        self._append(b"\x2F", (), Response(self, **handlers))

    def pull(self, n=-1, qid=-1, **handlers):
        """ Queue PULL_ALL. """
        # Just ignore n and qid, it is not supported in the Bolt 3 Protocol.
        log.debug("[#%04X] C: PULL_ALL", self.local_port)
        self._append(b"\x3F", (), Response(self, **handlers))
        self._is_reset = False

    def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, db=None, **handlers):
        """ Queue a BEGIN message.

        :raise ConfigurationError: if ``db`` is given (unsupported in Bolt 3)
        :raise TypeError: if bookmarks, metadata or timeout have the wrong type
        """
        if db is not None:
            raise ConfigurationError("Database name parameter for selecting database is not supported in Bolt Protocol {!r}. Database name {!r}.".format(Bolt3.PROTOCOL_VERSION, db))
        extra = {}
        if mode in (READ_ACCESS, "r"):
            extra["mode"] = "r"  # It will default to mode "w" if nothing is specified
        if bookmarks:
            try:
                extra["bookmarks"] = list(bookmarks)
            except TypeError:
                raise TypeError("Bookmarks must be provided within an iterable")
        if metadata:
            try:
                extra["tx_metadata"] = dict(metadata)
            except TypeError:
                raise TypeError("Metadata must be coercible to a dict")
        if timeout:
            try:
                # Seconds -> milliseconds
                extra["tx_timeout"] = int(1000 * timeout)
            except TypeError:
                raise TypeError("Timeout must be specified as a number of seconds")
        log.debug("[#%04X] C: BEGIN %r", self.local_port, extra)
        self._append(b"\x11", (extra,), Response(self, **handlers))
        self._is_reset = False

    def commit(self, **handlers):
        """ Queue COMMIT. """
        log.debug("[#%04X] C: COMMIT", self.local_port)
        self._append(b"\x12", (), CommitResponse(self, **handlers))

    def rollback(self, **handlers):
        """ Queue ROLLBACK. """
        log.debug("[#%04X] C: ROLLBACK", self.local_port)
        self._append(b"\x13", (), Response(self, **handlers))

    def _append(self, signature, fields=(), response=None):
        """ Add a message to the outgoing queue.

        :arg signature: the signature of the message
        :arg fields: the fields of the message as a tuple
        :arg response: a response object to handle callbacks
        """
        self.packer.pack_struct(signature, fields)
        # NOTE(review): the second chunk() presumably terminates the
        # message — confirm against Outbox
        self.outbox.chunk()
        self.outbox.chunk()
        self.responses.append(response)

    def reset(self):
        """ Add a RESET message to the outgoing queue, send
        it and consume all remaining messages.

        :raise BoltProtocolError: if the server answers RESET with FAILURE
        """

        def fail(metadata):
            raise BoltProtocolError("RESET failed %r" % metadata, address=self.unresolved_address)

        log.debug("[#%04X] C: RESET", self.local_port)
        self._append(b"\x0F", response=Response(self, on_failure=fail))
        self.send_all()
        self.fetch_all()
        self._is_reset = True

    def _send_all(self):
        """ Write any buffered outbox data to the socket and clear it. """
        data = self.outbox.view()
        if data:
            self.socket.sendall(data)
            self.outbox.clear()

    def send_all(self):
        """ Send all queued messages to the server.

        :raise ServiceUnavailable: if the connection is closed or defunct
        """
        if self.closed():
            raise ServiceUnavailable("Failed to write to closed connection {!r} ({!r})".format(
                self.unresolved_address, self.server_info.address))

        if self.defunct():
            raise ServiceUnavailable("Failed to write to defunct connection {!r} ({!r})".format(
                self.unresolved_address, self.server_info.address))

        try:
            self._send_all()
        except (IOError, OSError) as error:
            log.error("Failed to write data to connection "
                      "{!r} ({!r}); ({!r})".format(self.unresolved_address,
                                                  self.server_info.address,
                                                  "; ".join(map(repr, error.args))))
            if self.pool:
                self.pool.deactivate(address=self.unresolved_address)
            raise

    def fetch_message(self):
        """ Receive at least one message from the server, if available.

        :return: 2-tuple of number of detail messages and number of summary
                 messages fetched
        :raise ServiceUnavailable: if the connection is closed or defunct
        :raise BoltProtocolError: on an unrecognised summary signature
        """
        if self._closed:
            raise ServiceUnavailable("Failed to read from closed connection {!r} ({!r})".format(
                self.unresolved_address, self.server_info.address))

        if self._defunct:
            raise ServiceUnavailable("Failed to read from defunct connection {!r} ({!r})".format(
                self.unresolved_address, self.server_info.address))

        if not self.responses:
            return 0, 0

        # Receive exactly one message
        try:
            details, summary_signature, summary_metadata = next(self.inbox)
        except (IOError, OSError) as error:
            log.error("Failed to read data from connection "
                      "{!r} ({!r}); ({!r})".format(self.unresolved_address,
                                                  self.server_info.address,
                                                  "; ".join(map(repr, error.args))))
            if self.pool:
                self.pool.deactivate(address=self.unresolved_address)
            raise

        if details:
            log.debug("[#%04X] S: RECORD * %d", self.local_port, len(details))  # Do not log any data
            self.responses[0].on_records(details)

        if summary_signature is None:
            return len(details), 0

        response = self.responses.popleft()
        response.complete = True
        if summary_signature == b"\x70":
            log.debug("[#%04X] S: SUCCESS %r", self.local_port, summary_metadata)
            response.on_success(summary_metadata or {})
        elif summary_signature == b"\x7E":
            log.debug("[#%04X] S: IGNORED", self.local_port)
            response.on_ignored(summary_metadata or {})
        elif summary_signature == b"\x7F":
            log.debug("[#%04X] S: FAILURE %r", self.local_port, summary_metadata)
            try:
                response.on_failure(summary_metadata or {})
            except (ServiceUnavailable, DatabaseUnavailable):
                if self.pool:
                    self.pool.deactivate(address=self.unresolved_address)
                raise
            except (NotALeader, ForbiddenOnReadOnlyDatabase):
                if self.pool:
                    self.pool.on_write_failure(address=self.unresolved_address)
                raise
        else:
            # The signature is bytes (compared with bytes literals above);
            # "%02X" would raise TypeError on bytes, so render it as hex.
            raise BoltProtocolError("Unexpected response message with signature %s" % summary_signature.hex().upper(),
                                    address=self.unresolved_address)

        return len(details), 1

    def _set_defunct(self, error=None):
        """ Mark the connection defunct after a failed read, close it,
        deactivate its address in the pool, and raise.

        :raise BoltIncompleteCommitError: if a COMMIT is still awaiting a reply
        :raise ServiceUnavailable: when the pool is a ``BoltPool``
        :raise SessionExpired: otherwise
        """
        direct_driver = isinstance(self.pool, BoltPool)

        message = ("Failed to read from defunct connection {!r} ({!r})".format(
            self.unresolved_address, self.server_info.address))

        log.error(message)
        # We were attempting to receive data but the connection
        # has unexpectedly terminated. So, we need to close the
        # connection from the client side, and remove the address
        # from the connection pool.
        self._defunct = True
        self.close()
        if self.pool:
            self.pool.deactivate(address=self.unresolved_address)
        # Iterate through the outstanding responses, and if any correspond
        # to COMMIT requests then raise an error to signal that we are
        # unable to confirm that the COMMIT completed successfully.
        for response in self.responses:
            if isinstance(response, CommitResponse):
                raise BoltIncompleteCommitError(message, address=None)

        if direct_driver:
            raise ServiceUnavailable(message)
        else:
            raise SessionExpired(message)

    def timedout(self):
        """ True when a non-negative max lifetime is set and the connection
        is at least that old; a negative lifetime disables the check. """
        return 0 <= self._max_connection_lifetime <= perf_counter() - self._creation_timestamp

    def fetch_all(self):
        """ Fetch all outstanding messages.

        :return: 2-tuple of number of detail messages and number of summary
                 messages fetched
        """
        detail_count = summary_count = 0
        while self.responses:
            response = self.responses[0]
            while not response.complete:
                detail_delta, summary_delta = self.fetch_message()
                detail_count += detail_delta
                summary_count += summary_delta
        return detail_count, summary_count

    def close(self):
        """ Close the connection, sending GOODBYE first (best effort) unless
        it is already defunct.
        """
        if not self._closed:
            if not self._defunct:
                log.debug("[#%04X] C: GOODBYE", self.local_port)
                self._append(b"\x02", ())
                try:
                    self._send_all()
                except Exception:
                    # Best effort: failing to say GOODBYE must not stop the
                    # socket from being closed.
                    pass
            log.debug("[#%04X] C: <CLOSE>", self.local_port)
            try:
                self.socket.close()
            except IOError:
                pass
            finally:
                self._closed = True

    def closed(self):
        return self._closed

    def defunct(self):
        return self._defunct
def packb(cls, *values):
    """ Pack every value in *values*, in order, into one byte string. """
    buf = BytesIO()
    pack = Packer(buf).pack
    for item in values:
        pack(item)
    return buf.getvalue()
def test_map_size_overflow(self):
    """ A map header for 2**32 entries is too large and must overflow. """
    packer = Packer(BytesIO())
    with raises(OverflowError):
        packer.pack_map_header(2 ** 32)
def test_map_stream(self):
    """ A streamed map (header, key/value pairs, end marker) packs to the
    expected bytes and unpacks back to a plain dict. """
    packed_value = b"\xDB\x81A\x01\x81B\x02\xDF"
    unpacked_value = {u"A": 1, u"B": 2}
    stream_out = BytesIO()
    packer = Packer(stream_out)
    packer.pack_map_stream_header()
    packer.pack(u"A")
    packer.pack(1)
    packer.pack(u"B")
    packer.pack(2)
    packer.pack_end_of_stream()
    packed = stream_out.getvalue()
    # Explicit comparisons rather than ``assert`` so the checks still run
    # when Python is started with -O.
    if packed != packed_value:
        raise AssertionError("Packed value is %r instead of expected %r" %
                             (packed, packed_value))
    unpacked = Unpacker(UnpackableBuffer(packed)).unpack()
    if unpacked != unpacked_value:
        raise AssertionError("Unpacked value %r is not equal to expected %r" %
                             (unpacked, unpacked_value))
class Bolt:
    """ Server connection speaking the Bolt protocol.

    A :class:`.Bolt` connection should be constructed following a
    successful Bolt handshake and takes the socket over which
    the handshake was carried out.

    .. note:: logs protocol traffic at DEBUG level and I/O failures at
              ERROR level
    """

    #: The protocol version in use on this connection
    protocol_version = 0

    #: Server details for this connection
    server = None

    in_use = False

    # Set to True once close() has run
    _closed = False

    # Set to True by _set_defunct() after an unexpected read failure
    _defunct = False

    #: The pool of which this connection is a member
    pool = None

    #: Error class used for raising connection errors
    # TODO: separate errors for connector API
    Error = ServiceUnavailable

    @classmethod
    def ping(cls, address, *, timeout=None, **config):
        """ Attempt to establish a Bolt connection, returning the
        agreed Bolt protocol version if successful, or None if the
        server is unavailable.
        """
        config = PoolConfig.consume(config)
        try:
            s, protocol_version = connect(address, timeout=timeout, config=config)
        except ServiceUnavailable:
            return None
        else:
            s.close()
            return protocol_version

    @classmethod
    def open(cls, address, *, auth=None, timeout=None, **config):
        """ Open a new Bolt connection to a given server address.

        :param address: server address passed to ``connect``
        :param auth: auth details, as accepted by ``__init__``
        :param timeout: connection timeout passed to ``connect``
        :param config: settings consumed into a ``PoolConfig``
        :return: a connected :class:`.Bolt` that has completed HELLO
        """
        config = PoolConfig.consume(config)
        s, config.protocol_version = connect(address, timeout=timeout, config=config)
        connection = Bolt(address, s, auth=auth, **config)
        connection.hello()
        return connection

    def __init__(self, unresolved_address, sock, *, auth=None, protocol_version=None, **config):
        """ Set up per-connection state on a handshaken socket.

        :param unresolved_address: address as originally requested
        :param sock: connected socket; its peer name populates ``server``
        :param auth: falsy for no auth, a 2- or 3-tuple passed to
            ``Auth("basic", ...)``, or any object whose ``vars()`` is the
            auth dict
        :param protocol_version: negotiated Bolt protocol version
        :param config: consumed into a ``PoolConfig``
        :raise AuthError: if auth details cannot be determined, or the
            credentials entry is present but ``None``
        """
        self.config = PoolConfig.consume(config)
        self.protocol_version = protocol_version
        self.unresolved_address = unresolved_address
        self.socket = sock
        self.server = ServerInfo(Address(sock.getpeername()), protocol_version)
        self.outbox = Outbox()
        self.inbox = Inbox(BufferedSocket(self.socket, 32768), on_error=self._set_defunct)
        self.packer = Packer(self.outbox)
        self.unpacker = Unpacker(self.inbox)
        self.responses = deque()
        self._max_connection_lifetime = self.config.max_age
        self._creation_timestamp = perf_counter()

        # Determine the user agent
        user_agent = self.config.user_agent
        if user_agent:
            self.user_agent = user_agent
        else:
            self.user_agent = get_user_agent()

        # Determine auth details
        if not auth:
            self.auth_dict = {}
        elif isinstance(auth, tuple) and 2 <= len(auth) <= 3:
            from neo4j import Auth
            self.auth_dict = vars(Auth("basic", *auth))
        else:
            try:
                self.auth_dict = vars(auth)
            except (KeyError, TypeError):
                # Wrap in a 1-tuple: a bare tuple operand to % would be
                # unpacked and raise TypeError instead of AuthError.
                raise AuthError("Cannot determine auth details from %r" % (auth,))

        # Check for missing password
        try:
            credentials = self.auth_dict["credentials"]
        except KeyError:
            pass
        else:
            if credentials is None:
                raise AuthError("Password cannot be None")

    @property
    def secure(self):
        """ True if the socket is an ``SSLSocket``. """
        return isinstance(self.socket, SSLSocket)

    @property
    def der_encoded_server_certificate(self):
        """ The peer certificate in binary (DER) form. """
        return self.socket.getpeercert(binary_form=True)

    @property
    def local_port(self):
        """ Local port of the socket, or 0 if it cannot be read. """
        try:
            return self.socket.getsockname()[1]
        except IOError:
            return 0

    def hello(self):
        """ Send HELLO with user agent and auth details (credentials masked
        in the log) and wait for the reply; on success the reply metadata
        is merged into ``server.metadata``. """
        headers = {"user_agent": self.user_agent}
        headers.update(self.auth_dict)
        logged_headers = dict(headers)
        if "credentials" in logged_headers:
            logged_headers["credentials"] = "*******"
        log.debug("[#%04X] C: HELLO %r", self.local_port, logged_headers)
        self._append(b"\x01", (headers, ),
                     response=InitResponse(self, on_success=self.server.metadata.update))
        self.send_all()
        self.fetch_all()

    def __del__(self):
        try:
            self.close()
        except Exception:
            # Finalizers must not raise; closing here is best effort.
            pass

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def run(self, statement, parameters=None, mode=None, bookmarks=None, metadata=None, timeout=None, **handlers):
        """ Queue a RUN message (not sent until ``send_all``).

        :raise TypeError: if bookmarks, metadata or timeout have the wrong type
        """
        if not parameters:
            parameters = {}
        extra = {}
        if mode:
            extra["mode"] = mode
        if bookmarks:
            try:
                extra["bookmarks"] = list(bookmarks)
            except TypeError:
                raise TypeError("Bookmarks must be provided within an iterable")
        if metadata:
            try:
                extra["tx_metadata"] = dict(metadata)
            except TypeError:
                raise TypeError("Metadata must be coercible to a dict")
        if timeout:
            try:
                # Seconds -> milliseconds
                extra["tx_timeout"] = int(1000 * timeout)
            except TypeError:
                raise TypeError("Timeout must be specified as a number of seconds")
        fields = (statement, parameters, extra)
        log.debug("[#%04X] C: RUN %s", self.local_port, " ".join(map(repr, fields)))
        if statement.upper() == u"COMMIT":
            self._append(b"\x10", fields, CommitResponse(self, **handlers))
        else:
            self._append(b"\x10", fields, Response(self, **handlers))

    def discard_all(self, **handlers):
        """ Queue DISCARD_ALL. """
        log.debug("[#%04X] C: DISCARD_ALL", self.local_port)
        self._append(b"\x2F", (), Response(self, **handlers))

    def pull_all(self, **handlers):
        """ Queue PULL_ALL. """
        log.debug("[#%04X] C: PULL_ALL", self.local_port)
        self._append(b"\x3F", (), Response(self, **handlers))

    def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, **handlers):
        """ Queue a BEGIN message.

        :raise TypeError: if bookmarks, metadata or timeout have the wrong type
        """
        extra = {}
        if mode:
            extra["mode"] = mode
        if bookmarks:
            try:
                extra["bookmarks"] = list(bookmarks)
            except TypeError:
                raise TypeError("Bookmarks must be provided within an iterable")
        if metadata:
            try:
                extra["tx_metadata"] = dict(metadata)
            except TypeError:
                raise TypeError("Metadata must be coercible to a dict")
        if timeout:
            try:
                # Seconds -> milliseconds
                extra["tx_timeout"] = int(1000 * timeout)
            except TypeError:
                raise TypeError("Timeout must be specified as a number of seconds")
        log.debug("[#%04X] C: BEGIN %r", self.local_port, extra)
        self._append(b"\x11", (extra, ), Response(self, **handlers))

    def commit(self, **handlers):
        """ Queue COMMIT. """
        log.debug("[#%04X] C: COMMIT", self.local_port)
        self._append(b"\x12", (), CommitResponse(self, **handlers))

    def rollback(self, **handlers):
        """ Queue ROLLBACK. """
        log.debug("[#%04X] C: ROLLBACK", self.local_port)
        self._append(b"\x13", (), Response(self, **handlers))

    def _append(self, signature, fields=(), response=None):
        """ Add a message to the outgoing queue.

        :arg signature: the signature of the message
        :arg fields: the fields of the message as a tuple
        :arg response: a response object to handle callbacks
        """
        self.packer.pack_struct(signature, fields)
        # NOTE(review): the second chunk() presumably terminates the
        # message — confirm against Outbox
        self.outbox.chunk()
        self.outbox.chunk()
        self.responses.append(response)

    def reset(self):
        """ Add a RESET message to the outgoing queue, send
        it and consume all remaining messages.

        :raise ProtocolError: if the server answers RESET with FAILURE
        """

        def fail(metadata):
            raise ProtocolError("RESET failed %r" % metadata)

        log.debug("[#%04X] C: RESET", self.local_port)
        self._append(b"\x0F", response=Response(self, on_failure=fail))
        self.send_all()
        self.fetch_all()

    def _send_all(self):
        """ Write any buffered outbox data to the socket and clear it. """
        data = self.outbox.view()
        if data:
            self.socket.sendall(data)
            self.outbox.clear()

    def send_all(self):
        """ Send all queued messages to the server.

        :raise Error: if the connection is closed or defunct
        """
        if self.closed():
            raise self.Error("Failed to write to closed connection "
                             "{!r} ({!r})".format(self.unresolved_address,
                                                  self.server.address))
        if self.defunct():
            raise self.Error("Failed to write to defunct connection "
                             "{!r} ({!r})".format(self.unresolved_address,
                                                  self.server.address))
        try:
            self._send_all()
        except (IOError, OSError) as error:
            log.error("Failed to write data to connection "
                      "{!r} ({!r}); ({!r})".format(self.unresolved_address,
                                                  self.server.address,
                                                  "; ".join(map(repr, error.args))))
            if self.pool:
                self.pool.deactivate(self.unresolved_address)
            raise

    def fetch_message(self):
        """ Receive at least one message from the server, if available.

        :return: 2-tuple of number of detail messages and number of summary
                 messages fetched
        :raise Error: if the connection is closed or defunct
        :raise ProtocolError: on an unrecognised summary signature
        """
        if self._closed:
            raise self.Error("Failed to read from closed connection "
                             "{!r} ({!r})".format(self.unresolved_address,
                                                  self.server.address))
        if self._defunct:
            raise self.Error("Failed to read from defunct connection "
                             "{!r} ({!r})".format(self.unresolved_address,
                                                  self.server.address))
        if not self.responses:
            return 0, 0

        # Receive exactly one message
        try:
            details, summary_signature, summary_metadata = next(self.inbox)
        except (IOError, OSError) as error:
            log.error("Failed to read data from connection "
                      "{!r} ({!r}); ({!r})".format(self.unresolved_address,
                                                  self.server.address,
                                                  "; ".join(map(repr, error.args))))
            if self.pool:
                self.pool.deactivate(self.unresolved_address)
            raise

        if details:
            log.debug("[#%04X] S: RECORD * %d", self.local_port, len(details))  # TODO
            self.responses[0].on_records(details)

        if summary_signature is None:
            return len(details), 0

        response = self.responses.popleft()
        response.complete = True
        if summary_signature == b"\x70":
            log.debug("[#%04X] S: SUCCESS %r", self.local_port, summary_metadata)
            response.on_success(summary_metadata or {})
        elif summary_signature == b"\x7E":
            log.debug("[#%04X] S: IGNORED", self.local_port)
            response.on_ignored(summary_metadata or {})
        elif summary_signature == b"\x7F":
            log.debug("[#%04X] S: FAILURE %r", self.local_port, summary_metadata)
            try:
                response.on_failure(summary_metadata or {})
            except (ConnectionExpired, ServiceUnavailable, DatabaseUnavailableError):
                if self.pool:
                    self.pool.deactivate(self.unresolved_address)
                raise
            except (NotALeaderError, ForbiddenOnReadOnlyDatabaseError):
                if self.pool:
                    self.pool.on_write_failure(self.unresolved_address)
                raise
        else:
            # The signature is bytes (compared with bytes literals above);
            # "%02X" would raise TypeError on bytes, so render it as hex.
            raise ProtocolError("Unexpected response message with "
                                "signature %s" % summary_signature.hex().upper())

        return len(details), 1

    def _set_defunct(self, error=None):
        """ Mark the connection defunct after a failed read, close it,
        deactivate its address in the pool, and raise.

        :raise IncompleteCommitError: if a COMMIT is still awaiting a reply
        :raise Error: otherwise
        """
        message = ("Failed to read from defunct connection "
                   "{!r} ({!r})".format(self.unresolved_address,
                                        self.server.address))
        log.error(message)
        # We were attempting to receive data but the connection
        # has unexpectedly terminated. So, we need to close the
        # connection from the client side, and remove the address
        # from the connection pool.
        self._defunct = True
        self.close()
        if self.pool:
            self.pool.deactivate(self.unresolved_address)
        # Iterate through the outstanding responses, and if any correspond
        # to COMMIT requests then raise an error to signal that we are
        # unable to confirm that the COMMIT completed successfully.
        for response in self.responses:
            if isinstance(response, CommitResponse):
                raise IncompleteCommitError(message)
        raise self.Error(message)

    def timedout(self):
        """ True when a non-negative max age is set and the connection is at
        least that old; a negative max age disables the check. """
        return 0 <= self._max_connection_lifetime <= perf_counter(
        ) - self._creation_timestamp

    def fetch_all(self):
        """ Fetch all outstanding messages.

        :return: 2-tuple of number of detail messages and number of summary
                 messages fetched
        """
        detail_count = summary_count = 0
        while self.responses:
            response = self.responses[0]
            while not response.complete:
                detail_delta, summary_delta = self.fetch_message()
                detail_count += detail_delta
                summary_count += summary_delta
        return detail_count, summary_count

    def close(self):
        """ Close the connection, sending GOODBYE first (best effort) unless
        it is already defunct.
        """
        if not self._closed:
            if not self._defunct:
                log.debug("[#%04X] C: GOODBYE", self.local_port)
                self._append(b"\x02", ())
                try:
                    self._send_all()
                except Exception:
                    # Best effort: failing to say GOODBYE must not stop the
                    # socket from being closed.
                    pass
            log.debug("[#%04X] C: <CLOSE>", self.local_port)
            try:
                self.socket.close()
            except IOError:
                pass
            finally:
                self._closed = True

    def closed(self):
        return self._closed

    def defunct(self):
        return self._defunct
def encode_message(cls, tag, *fields):
    """ Encode a structure message from a tag and its fields.

    The output is a tiny-struct marker byte (0xB0 plus the field count), then
    the tag byte, then each field serialised with ``Packer``.

    :arg cls: not used by this function
    :arg tag: the structure tag, as an int in 0..255 or a single byte
        (e.g. ``b"\\x70"``)
    :arg fields: the structure fields to pack
    :return: the encoded message as a bytearray
    :raise ValueError: if ``tag`` is a multi-byte value, or if there are more
        than 15 fields (a tiny-struct marker only has room for 0..15; a larger
        count would spill into unrelated marker bytes)
    """
    if isinstance(tag, (bytes, bytearray)):
        if len(tag) != 1:
            raise ValueError("Structure tag must be a single byte, got %r" % (tag,))
        tag = bytearray(tag)[0]
    field_count = len(fields)
    if field_count > 0x0F:
        raise ValueError("A tiny struct holds at most 15 fields, got %d" % field_count)
    b = BytesIO()
    packer = Packer(b)
    for field in fields:
        packer.pack(field)
    return bytearray([0xB0 + field_count, tag]) + b.getvalue()