def __init__(self, protocol_version, address, sock, **config):
    """ Initialise connection state over an already-handshaken socket.

    :param protocol_version: Bolt protocol version for this connection
    :param address: address this connection was created for
    :param sock: the connected socket
    :param config: optional settings; this reads ``error_handler``,
        ``max_connection_lifetime``, ``user_agent``, ``auth`` and
        ``der_encoded_server_certificate``
    :raise TypeError: if ``auth`` is truthy but is neither a 2-3 item
        tuple nor an object accepted by :func:`vars`
    """
    self.protocol_version = protocol_version
    self.address = address
    self.socket = sock
    self.error_handler = config.get("error_handler", ConnectionErrorHandler())
    self.server = ServerInfo(SocketAddress.from_socket(sock), protocol_version)
    self.input_buffer = ChunkedInputBuffer()
    self.output_buffer = ChunkedOutputBuffer()
    self.packer = Packer(self.output_buffer)
    self.unpacker = Unpacker()
    self.responses = deque()
    self._max_connection_lifetime = config.get(
        "max_connection_lifetime", DEFAULT_MAX_CONNECTION_LIFETIME)
    self._creation_timestamp = perf_counter()

    # Determine the user agent and ensure it is a Unicode value
    user_agent = config.get("user_agent", DEFAULT_USER_AGENT)
    if isinstance(user_agent, bytes):
        user_agent = user_agent.decode("UTF-8")
    self.user_agent = user_agent

    # Determine auth details
    auth = config.get("auth")
    if not auth:
        self.auth_dict = {}
    elif isinstance(auth, tuple) and 2 <= len(auth) <= 3:
        self.auth_dict = vars(AuthToken("basic", *auth))
    else:
        try:
            self.auth_dict = vars(auth)
        except (KeyError, TypeError):
            # Wrap in a 1-tuple: a tuple ``auth`` (e.g. of the wrong length)
            # would otherwise be unpacked as the format arguments.
            raise TypeError("Cannot determine auth details from %r" % (auth,))

    # Pick up the server certificate, if any
    self.der_encoded_server_certificate = config.get(
        "der_encoded_server_certificate")
def __init__(self, protocol_version, unresolved_address, sock, **config):
    """ Initialise connection state over an already-handshaken socket.

    :param protocol_version: Bolt protocol version for this connection
    :param unresolved_address: address this connection was requested for
    :param sock: the connected socket
    :param config: optional settings; this reads
        ``max_connection_lifetime``, ``user_agent``, ``auth`` and
        ``der_encoded_server_certificate``
    :raise AuthError: if ``auth`` cannot be interpreted, or if it carries
        a ``credentials`` entry that is ``None``
    """
    self.protocol_version = protocol_version
    self.unresolved_address = unresolved_address
    self.socket = sock
    self.server = ServerInfo(SocketAddress.from_socket(sock), protocol_version)
    self.outbox = Outbox()
    # Read failures on the inbox are routed to ``_set_defunct``.
    self.inbox = Inbox(BufferedSocket(self.socket, 32768),
                       on_error=self._set_defunct)
    self.packer = Packer(self.outbox)
    self.unpacker = Unpacker(self.inbox)
    self.responses = deque()
    self._max_connection_lifetime = config.get(
        "max_connection_lifetime", DEFAULT_MAX_CONNECTION_LIFETIME)
    self._creation_timestamp = perf_counter()

    # Determine the user agent and ensure it is a Unicode value
    user_agent = config.get("user_agent", get_user_agent())
    if isinstance(user_agent, bytes):
        user_agent = user_agent.decode("UTF-8")
    self.user_agent = user_agent

    # Determine auth details
    auth = config.get("auth")
    if not auth:
        self.auth_dict = {}
    elif isinstance(auth, tuple) and 2 <= len(auth) <= 3:
        self.auth_dict = vars(AuthToken("basic", *auth))
    else:
        try:
            self.auth_dict = vars(auth)
        except (KeyError, TypeError):
            # Wrap in a 1-tuple so that a tuple ``auth`` is formatted as one
            # value; otherwise ``%`` unpacks it and raises TypeError instead
            # of the AuthError callers expect.
            raise AuthError("Cannot determine auth details from %r" % (auth,))

    # Check for missing password
    try:
        credentials = self.auth_dict["credentials"]
    except KeyError:
        pass
    else:
        if credentials is None:
            raise AuthError("Password cannot be None")

    # Pick up the server certificate, if any
    self.der_encoded_server_certificate = config.get(
        "der_encoded_server_certificate")
class Connection(object):
    """ Server connection for Bolt protocol v1.

    A :class:`.Connection` should be constructed following a
    successful Bolt handshake and takes the socket over which
    the handshake was carried out.

    .. note:: logs at INFO level
    """

    #: The protocol version in use on this connection
    protocol_version = 0

    #: Server details for this connection
    server = None

    in_use = False

    _closed = False

    _defunct = False

    #: The pool of which this connection is a member
    pool = None

    #: Error class used for raising connection errors
    Error = ServiceUnavailable

    # Statement text most recently sent by ``run`` while the server supports
    # "statement_reuse"; reset to None on any non-SUCCESS summary.
    _last_run_statement = None

    def __init__(self, protocol_version, address, sock, **config):
        self.protocol_version = protocol_version
        self.address = address
        self.socket = sock
        self.error_handler = config.get("error_handler", ConnectionErrorHandler())
        self.server = ServerInfo(SocketAddress.from_socket(sock), protocol_version)
        self.input_buffer = ChunkedInputBuffer()
        self.output_buffer = ChunkedOutputBuffer()
        self.packer = Packer(self.output_buffer)
        self.unpacker = Unpacker()
        self.responses = deque()
        self._max_connection_lifetime = config.get(
            "max_connection_lifetime", DEFAULT_MAX_CONNECTION_LIFETIME)
        self._creation_timestamp = perf_counter()

        # Determine the user agent and ensure it is a Unicode value
        user_agent = config.get("user_agent", DEFAULT_USER_AGENT)
        if isinstance(user_agent, bytes):
            user_agent = user_agent.decode("UTF-8")
        self.user_agent = user_agent

        # Determine auth details
        auth = config.get("auth")
        if not auth:
            self.auth_dict = {}
        elif isinstance(auth, tuple) and 2 <= len(auth) <= 3:
            self.auth_dict = vars(AuthToken("basic", *auth))
        else:
            try:
                self.auth_dict = vars(auth)
            except (KeyError, TypeError):
                # 1-tuple so a tuple ``auth`` is not unpacked by ``%``.
                raise AuthError("Cannot determine auth details from %r" % (auth,))

        # Check for missing password
        try:
            credentials = self.auth_dict["credentials"]
        except KeyError:
            pass
        else:
            if credentials is None:
                raise AuthError("Password cannot be None")

        # Pick up the server certificate, if any
        self.der_encoded_server_certificate = config.get(
            "der_encoded_server_certificate")

    @property
    def secure(self):
        """ True if SSL is available and the socket is an SSL socket. """
        return SSL_AVAILABLE and isinstance(self.socket, SSLSocket)

    @property
    def local_port(self):
        """ Local port of the socket, or 0 if it cannot be determined. """
        try:
            return self.socket.getsockname()[1]
        except IOError:
            return 0

    def init(self):
        """ Send INIT (user agent + auth), sync, then update byte support. """
        log_debug("[#%04X] C: INIT %r {...}", self.local_port, self.user_agent)
        self._append(b"\x01", (self.user_agent, self.auth_dict),
                     response=InitResponse(
                         self, on_success=self.server.metadata.update))
        self.sync()
        self.packer.supports_bytes = self.server.supports("bytes")

    def hello(self):
        """ Send HELLO (user agent merged with auth), sync, then update
        byte support. Credentials are masked in the debug log.
        """
        headers = {"user_agent": self.user_agent}
        headers.update(self.auth_dict)
        logged_headers = dict(headers)
        if "credentials" in logged_headers:
            logged_headers["credentials"] = "*******"
        log_debug("[#%04X] C: HELLO %r", self.local_port, logged_headers)
        self._append(b"\x01", (headers, ),
                     response=InitResponse(
                         self, on_success=self.server.metadata.update))
        self.sync()
        self.packer.supports_bytes = self.server.supports("bytes")

    def __del__(self):
        # Best-effort cleanup; errors during interpreter teardown are ignored.
        try:
            self.close()
        except:
            pass

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    @staticmethod
    def _transaction_extra(bookmarks, metadata, timeout):
        """ Build the Bolt v3+ "extra" map shared by RUN and BEGIN.

        :raise TypeError: if bookmarks are not iterable, metadata is not
            dict-coercible, or timeout is not numeric
        """
        extra = {}
        if bookmarks:
            try:
                extra["bookmarks"] = list(bookmarks)
            except TypeError:
                raise TypeError(
                    "Bookmarks must be provided within an iterable")
        if metadata:
            try:
                extra["tx_metadata"] = dict(metadata)
            except TypeError:
                raise TypeError("Metadata must be coercible to a dict")
        if timeout:
            try:
                # Timeout is given in seconds; sent in milliseconds.
                extra["tx_timeout"] = int(1000 * timeout)
            except TypeError:
                raise TypeError(
                    "Timeout must be specified as a number of seconds")
        return extra

    def run(self, statement, parameters=None, bookmarks=None, metadata=None,
            timeout=None, **handlers):
        """ Queue a RUN message.

        When the server supports "statement_reuse", an immediate repeat of
        the previous statement (other than BEGIN/COMMIT/ROLLBACK) is sent
        as an empty string.

        :raise NotImplementedError: if metadata or timeout is given for a
            protocol version below 3
        """
        if self.server.supports("statement_reuse"):
            if statement.upper() not in (u"BEGIN", u"COMMIT", u"ROLLBACK"):
                if statement == self._last_run_statement:
                    statement = ""
                else:
                    self._last_run_statement = statement
        if not parameters:
            parameters = {}
        if self.protocol_version >= 3:
            extra = self._transaction_extra(bookmarks, metadata, timeout)
            fields = (statement, parameters, extra)
        else:
            if metadata:
                raise NotImplementedError(
                    "Transaction metadata is not supported in Bolt v%d"
                    % self.protocol_version)
            if timeout:
                raise NotImplementedError(
                    "Transaction timeouts are not supported in Bolt v%d"
                    % self.protocol_version)
            fields = (statement, parameters)
        log_debug("[#%04X] C: RUN %s", self.local_port,
                  " ".join(map(repr, fields)))
        self._append(b"\x10", fields, Response(self, **handlers))

    def discard_all(self, **handlers):
        """ Queue a DISCARD_ALL message. """
        log_debug("[#%04X] C: DISCARD_ALL", self.local_port)
        self._append(b"\x2F", (), Response(self, **handlers))

    def pull_all(self, **handlers):
        """ Queue a PULL_ALL message. """
        log_debug("[#%04X] C: PULL_ALL", self.local_port)
        self._append(b"\x3F", (), Response(self, **handlers))

    def begin(self, bookmarks=None, metadata=None, timeout=None, **handlers):
        """ Begin a transaction.

        Bolt v3+ sends a BEGIN message; older versions send RUN "BEGIN"
        followed by DISCARD_ALL.

        :raise NotImplementedError: if metadata or timeout is given for a
            protocol version below 3
        """
        if self.protocol_version >= 3:
            extra = self._transaction_extra(bookmarks, metadata, timeout)
            log_debug("[#%04X] C: BEGIN %r", self.local_port, extra)
            self._append(b"\x11", (extra, ), Response(self, **handlers))
        else:
            extra = {}
            if bookmarks:
                if self.protocol_version < 2:
                    # TODO 2.0: remove
                    extra["bookmark"] = last_bookmark(bookmarks)
                try:
                    extra["bookmarks"] = list(bookmarks)
                except TypeError:
                    raise TypeError(
                        "Bookmarks must be provided within an iterable")
            if metadata:
                raise NotImplementedError(
                    "Transaction metadata is not supported in Bolt v%d"
                    % self.protocol_version)
            if timeout:
                raise NotImplementedError(
                    "Transaction timeouts are not supported in Bolt v%d"
                    % self.protocol_version)
            self.run(u"BEGIN", extra, **handlers)
            self.discard_all(**handlers)

    def commit(self, **handlers):
        """ Commit: COMMIT message on v3+, RUN "COMMIT" + DISCARD_ALL before. """
        if self.protocol_version >= 3:
            log_debug("[#%04X] C: COMMIT", self.local_port)
            self._append(b"\x12", (), Response(self, **handlers))
        else:
            self.run(u"COMMIT", {}, **handlers)
            self.discard_all(**handlers)

    def rollback(self, **handlers):
        """ Roll back: ROLLBACK message on v3+, RUN "ROLLBACK" + DISCARD_ALL before. """
        if self.protocol_version >= 3:
            log_debug("[#%04X] C: ROLLBACK", self.local_port)
            self._append(b"\x13", (), Response(self, **handlers))
        else:
            self.run(u"ROLLBACK", {}, **handlers)
            self.discard_all(**handlers)

    def _append(self, signature, fields=(), response=None):
        """ Add a message to the outgoing queue.

        :arg signature: the signature of the message
        :arg fields: the fields of the message as a tuple
        :arg response: a response object to handle callbacks
        """
        self.packer.pack_struct(signature, fields)
        # NOTE(review): chunk() is called twice; presumably the first closes
        # the current chunk and the second emits the end-of-message marker.
        self.output_buffer.chunk()
        self.output_buffer.chunk()
        self.responses.append(response)

    def reset(self):
        """ Add a RESET message to the outgoing queue, send
        it and consume all remaining messages.
        """

        def fail(metadata):
            raise ProtocolError("RESET failed %r" % metadata)

        log_debug("[#%04X] C: RESET", self.local_port)
        self._append(b"\x0F", response=Response(self, on_failure=fail))
        self.sync()

    def send(self):
        """ Send queued messages, notifying the error handler of known errors. """
        try:
            self._send()
        except self.error_handler.known_errors as error:
            self.error_handler.handle(error, self.address)
            raise error

    def _send(self):
        """ Send all queued messages to the server.

        :raise Error: if there is data to send but the connection is
            closed or defunct
        """
        data = self.output_buffer.view()
        if not data:
            return
        if self.closed():
            raise self.Error(
                "Failed to write to closed connection {!r}".format(
                    self.server.address))
        if self.defunct():
            raise self.Error(
                "Failed to write to defunct connection {!r}".format(
                    self.server.address))
        self.socket.sendall(data)
        self.output_buffer.clear()

    def fetch(self):
        """ Fetch one message, notifying the error handler of known errors. """
        try:
            return self._fetch()
        except self.error_handler.known_errors as error:
            self.error_handler.handle(error, self.address)
            raise error

    def _fetch(self):
        """ Receive at least one message from the server, if available.

        :return: 2-tuple of number of detail messages and number of summary
                 messages fetched
        :raise Error: if the connection is closed or defunct
        :raise ProtocolError: on an unrecognised summary signature
        """
        if self.closed():
            raise self.Error(
                "Failed to read from closed connection {!r}".format(
                    self.server.address))
        if self.defunct():
            raise self.Error(
                "Failed to read from defunct connection {!r}".format(
                    self.server.address))
        if not self.responses:
            return 0, 0

        self._receive()

        details, summary_signature, summary_metadata = self._unpack()

        if details:
            log_debug("[#%04X] S: RECORD * %d", self.local_port,
                      len(details))  # TODO
            self.responses[0].on_records(details)

        if summary_signature is None:
            return len(details), 0

        response = self.responses.popleft()
        response.complete = True
        if summary_signature == b"\x70":
            log_debug("[#%04X] S: SUCCESS %r", self.local_port,
                      summary_metadata)
            response.on_success(summary_metadata or {})
        elif summary_signature == b"\x7E":
            self._last_run_statement = None
            log_debug("[#%04X] S: IGNORED", self.local_port)
            response.on_ignored(summary_metadata or {})
        elif summary_signature == b"\x7F":
            self._last_run_statement = None
            log_debug("[#%04X] S: FAILURE %r", self.local_port,
                      summary_metadata)
            response.on_failure(summary_metadata or {})
        else:
            self._last_run_statement = None
            # The signature is a one-byte bytes value; ``%X`` needs an int.
            raise ProtocolError(
                "Unexpected response message with signature %02X"
                % ord(summary_signature))

        return len(details), 1

    def _receive(self):
        """ Read one message into the input buffer.

        A socket error or an empty read marks the connection defunct,
        closes it and raises :attr:`Error`; a result of -1 raises
        KeyboardInterrupt.
        """
        try:
            received = self.input_buffer.receive_message(self.socket, 8192)
        except SocketError:
            received = 0
        else:
            if received == -1:
                raise KeyboardInterrupt()
        if not received:
            self._defunct = True
            self.close()
            raise self.Error(
                "Failed to read from defunct connection {!r}".format(
                    self.server.address))

    def _unpack(self):
        """ Decode buffered frames.

        :return: 3-tuple of (list of RECORD field lists, summary signature
                 or None, summary metadata or None)
        """
        unpacker = self.unpacker
        input_buffer = self.input_buffer

        details = []
        summary_signature = None
        summary_metadata = None
        more = True
        while more:
            unpacker.attach(input_buffer.frame())
            size, signature = unpacker.unpack_structure_header()
            if size > 1:
                raise ProtocolError("Expected one field")
            if signature == b"\x71":
                data = unpacker.unpack_list()
                details.append(data)
                more = input_buffer.frame_message()
            else:
                summary_signature = signature
                summary_metadata = unpacker.unpack_map()
                more = False
        return details, summary_signature, summary_metadata

    def timedout(self):
        """ True if a non-negative max lifetime has elapsed since creation. """
        return 0 <= self._max_connection_lifetime <= perf_counter(
        ) - self._creation_timestamp

    def sync(self):
        """ Send and fetch all outstanding messages.

        :return: 2-tuple of number of detail messages and number of summary
                 messages fetched
        """
        self.send()
        detail_count = summary_count = 0
        while self.responses:
            response = self.responses[0]
            while not response.complete:
                detail_delta, summary_delta = self.fetch()
                detail_count += detail_delta
                summary_count += summary_delta
        return detail_count, summary_count

    def close(self):
        """ Close the connection.

        On Bolt v3+ a GOODBYE is sent on a best-effort basis (skipped if
        the connection is already defunct). The socket is always closed
        and the connection marked closed, even if sending GOODBYE fails.
        """
        if self._closed:
            return
        try:
            if self.protocol_version >= 3 and not self._defunct:
                log_debug("[#%04X] C: GOODBYE", self.local_port)
                self._append(b"\x02", ())
                try:
                    self.send()
                except (ServiceUnavailable, IOError, OSError):
                    # Best effort: the peer may already be gone.
                    pass
        finally:
            log_debug("[#%04X] C: <CLOSE>", self.local_port)
            try:
                self.socket.close()
            except IOError:
                pass
            finally:
                self._closed = True

    def closed(self):
        return self._closed

    def defunct(self):
        return self._defunct
class Connection(object):
    """ Server connection for Bolt protocol v1.

    A :class:`.Connection` should be constructed following a
    successful Bolt handshake and takes the socket over which
    the handshake was carried out.

    .. note:: logs at INFO level
    """

    #: The protocol version in use on this connection
    protocol_version = 0

    #: Server details for this connection
    server = None

    in_use = False

    _closed = False

    _defunct = False

    #: The pool of which this connection is a member
    pool = None

    #: Error class used for raising connection errors
    # TODO: separate errors for connector API
    Error = ServiceUnavailable

    def __init__(self, protocol_version, unresolved_address, sock, **config):
        self.protocol_version = protocol_version
        self.unresolved_address = unresolved_address
        self.socket = sock
        self.server = ServerInfo(SocketAddress.from_socket(sock), protocol_version)
        self.outbox = Outbox()
        # Read failures on the inbox are routed to ``_set_defunct``.
        self.inbox = Inbox(BufferedSocket(self.socket, 32768),
                           on_error=self._set_defunct)
        self.packer = Packer(self.outbox)
        self.unpacker = Unpacker(self.inbox)
        self.responses = deque()
        self._max_connection_lifetime = config.get(
            "max_connection_lifetime", DEFAULT_MAX_CONNECTION_LIFETIME)
        self._creation_timestamp = perf_counter()

        # Determine the user agent and ensure it is a Unicode value
        user_agent = config.get("user_agent", get_user_agent())
        if isinstance(user_agent, bytes):
            user_agent = user_agent.decode("UTF-8")
        self.user_agent = user_agent

        # Determine auth details
        auth = config.get("auth")
        if not auth:
            self.auth_dict = {}
        elif isinstance(auth, tuple) and 2 <= len(auth) <= 3:
            self.auth_dict = vars(AuthToken("basic", *auth))
        else:
            try:
                self.auth_dict = vars(auth)
            except (KeyError, TypeError):
                # 1-tuple so a tuple ``auth`` is not unpacked by ``%``.
                raise AuthError("Cannot determine auth details from %r" % (auth,))

        # Check for missing password
        try:
            credentials = self.auth_dict["credentials"]
        except KeyError:
            pass
        else:
            if credentials is None:
                raise AuthError("Password cannot be None")

        # Pick up the server certificate, if any
        self.der_encoded_server_certificate = config.get(
            "der_encoded_server_certificate")

    @property
    def secure(self):
        """ True if the underlying socket is an SSL socket. """
        return isinstance(self.socket, SSLSocket)

    @property
    def local_port(self):
        """ Local port of the socket, or 0 if it cannot be determined. """
        try:
            return self.socket.getsockname()[1]
        except IOError:
            return 0

    def hello(self):
        """ Send HELLO (user agent merged with auth) and wait for the reply.
        Credentials are masked in the debug log.
        """
        headers = {"user_agent": self.user_agent}
        headers.update(self.auth_dict)
        logged_headers = dict(headers)
        if "credentials" in logged_headers:
            logged_headers["credentials"] = "*******"
        log.debug("[#%04X] C: HELLO %r", self.local_port, logged_headers)
        self._append(b"\x01", (headers, ),
                     response=InitResponse(
                         self, on_success=self.server.metadata.update))
        self.send_all()
        self.fetch_all()

    def __del__(self):
        # Best-effort cleanup; errors during interpreter teardown are ignored.
        try:
            self.close()
        except:
            pass

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    @staticmethod
    def _transaction_extra(mode, bookmarks, metadata, timeout):
        """ Build the "extra" map shared by RUN and BEGIN.

        :raise TypeError: if bookmarks are not iterable, metadata is not
            dict-coercible, or timeout is not numeric
        """
        extra = {}
        if mode:
            extra["mode"] = mode
        if bookmarks:
            try:
                extra["bookmarks"] = list(bookmarks)
            except TypeError:
                raise TypeError(
                    "Bookmarks must be provided within an iterable")
        if metadata:
            try:
                extra["tx_metadata"] = dict(metadata)
            except TypeError:
                raise TypeError("Metadata must be coercible to a dict")
        if timeout:
            try:
                # Timeout is given in seconds; sent in milliseconds.
                extra["tx_timeout"] = int(1000 * timeout)
            except TypeError:
                raise TypeError(
                    "Timeout must be specified as a number of seconds")
        return extra

    def run(self, statement, parameters=None, mode=None, bookmarks=None,
            metadata=None, timeout=None, **handlers):
        """ Queue a RUN message; a literal "COMMIT" statement gets a
        CommitResponse so an unconfirmed commit can be detected.
        """
        if not parameters:
            parameters = {}
        extra = self._transaction_extra(mode, bookmarks, metadata, timeout)
        fields = (statement, parameters, extra)
        log.debug("[#%04X] C: RUN %s", self.local_port,
                  " ".join(map(repr, fields)))
        if statement.upper() == u"COMMIT":
            self._append(b"\x10", fields, CommitResponse(self, **handlers))
        else:
            self._append(b"\x10", fields, Response(self, **handlers))

    def discard_all(self, **handlers):
        """ Queue a DISCARD_ALL message. """
        log.debug("[#%04X] C: DISCARD_ALL", self.local_port)
        self._append(b"\x2F", (), Response(self, **handlers))

    def pull_all(self, **handlers):
        """ Queue a PULL_ALL message. """
        log.debug("[#%04X] C: PULL_ALL", self.local_port)
        self._append(b"\x3F", (), Response(self, **handlers))

    def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None,
              **handlers):
        """ Queue a BEGIN message. """
        extra = self._transaction_extra(mode, bookmarks, metadata, timeout)
        log.debug("[#%04X] C: BEGIN %r", self.local_port, extra)
        self._append(b"\x11", (extra, ), Response(self, **handlers))

    def commit(self, **handlers):
        """ Queue a COMMIT message with a CommitResponse. """
        log.debug("[#%04X] C: COMMIT", self.local_port)
        self._append(b"\x12", (), CommitResponse(self, **handlers))

    def rollback(self, **handlers):
        """ Queue a ROLLBACK message. """
        log.debug("[#%04X] C: ROLLBACK", self.local_port)
        self._append(b"\x13", (), Response(self, **handlers))

    def _append(self, signature, fields=(), response=None):
        """ Add a message to the outgoing queue.

        :arg signature: the signature of the message
        :arg fields: the fields of the message as a tuple
        :arg response: a response object to handle callbacks
        """
        self.packer.pack_struct(signature, fields)
        # NOTE(review): chunk() is called twice; presumably the first closes
        # the current chunk and the second emits the end-of-message marker.
        self.outbox.chunk()
        self.outbox.chunk()
        self.responses.append(response)

    def reset(self):
        """ Add a RESET message to the outgoing queue, send
        it and consume all remaining messages.
        """

        def fail(metadata):
            raise ProtocolError("RESET failed %r" % metadata)

        log.debug("[#%04X] C: RESET", self.local_port)
        self._append(b"\x0F", response=Response(self, on_failure=fail))
        self.send_all()
        self.fetch_all()

    def _send_all(self):
        """ Write any queued outbox data to the socket, then clear it. """
        data = self.outbox.view()
        if data:
            self.socket.sendall(data)
            self.outbox.clear()

    def send_all(self):
        """ Send all queued messages to the server.

        :raise Error: if the connection is closed or defunct
        :raise IOError/OSError: on a write failure (logged, and the address
            deactivated in the pool if there is one)
        """
        if self.closed():
            raise self.Error("Failed to write to closed connection "
                             "{!r} ({!r})".format(self.unresolved_address,
                                                  self.server.address))
        if self.defunct():
            raise self.Error("Failed to write to defunct connection "
                             "{!r} ({!r})".format(self.unresolved_address,
                                                  self.server.address))
        try:
            self._send_all()
        except (IOError, OSError) as error:
            log.error("Failed to write data to connection "
                      "{!r} ({!r}); ({!r})".format(
                          self.unresolved_address,
                          self.server.address,
                          "; ".join(map(repr, error.args))))
            if self.pool:
                self.pool.deactivate(self.unresolved_address)
            raise

    def fetch_message(self):
        """ Receive at least one message from the server, if available.

        :return: 2-tuple of number of detail messages and number of summary
                 messages fetched
        :raise Error: if the connection is closed or defunct
        :raise ProtocolError: on an unrecognised summary signature
        """
        if self._closed:
            raise self.Error("Failed to read from closed connection "
                             "{!r} ({!r})".format(self.unresolved_address,
                                                  self.server.address))
        if self._defunct:
            raise self.Error("Failed to read from defunct connection "
                             "{!r} ({!r})".format(self.unresolved_address,
                                                  self.server.address))
        if not self.responses:
            return 0, 0

        # Receive exactly one message
        try:
            details, summary_signature, summary_metadata = next(self.inbox)
        except (IOError, OSError) as error:
            log.error("Failed to read data from connection "
                      "{!r} ({!r}); ({!r})".format(
                          self.unresolved_address,
                          self.server.address,
                          "; ".join(map(repr, error.args))))
            if self.pool:
                self.pool.deactivate(self.unresolved_address)
            raise

        if details:
            log.debug("[#%04X] S: RECORD * %d", self.local_port,
                      len(details))  # TODO
            self.responses[0].on_records(details)

        if summary_signature is None:
            return len(details), 0

        response = self.responses.popleft()
        response.complete = True
        if summary_signature == b"\x70":
            log.debug("[#%04X] S: SUCCESS %r", self.local_port,
                      summary_metadata)
            response.on_success(summary_metadata or {})
        elif summary_signature == b"\x7E":
            log.debug("[#%04X] S: IGNORED", self.local_port)
            response.on_ignored(summary_metadata or {})
        elif summary_signature == b"\x7F":
            log.debug("[#%04X] S: FAILURE %r", self.local_port,
                      summary_metadata)
            try:
                response.on_failure(summary_metadata or {})
            except (ConnectionExpired, ServiceUnavailable,
                    DatabaseUnavailableError):
                if self.pool:
                    self.pool.deactivate(self.unresolved_address)
                raise
            except (NotALeaderError, ForbiddenOnReadOnlyDatabaseError):
                if self.pool:
                    self.pool.remove_writer(self.unresolved_address)
                raise
        else:
            # The signature is a one-byte bytes value; ``%X`` needs an int.
            raise ProtocolError("Unexpected response message with "
                                "signature %02X" % ord(summary_signature))

        return len(details), 1

    def _set_defunct(self, error=None):
        """ Mark the connection defunct after a read failure and raise.

        :raise IncompleteCommitError: if any outstanding response is a
            CommitResponse
        :raise Error: otherwise
        """
        message = ("Failed to read from defunct connection "
                   "{!r} ({!r})".format(self.unresolved_address,
                                        self.server.address))
        log.error(message)
        # We were attempting to receive data but the connection
        # has unexpectedly terminated. So, we need to close the
        # connection from the client side, and remove the address
        # from the connection pool.
        self._defunct = True
        self.close()
        if self.pool:
            self.pool.deactivate(self.unresolved_address)
        # Iterate through the outstanding responses, and if any correspond
        # to COMMIT requests then raise an error to signal that we are
        # unable to confirm that the COMMIT completed successfully.
        for response in self.responses:
            if isinstance(response, CommitResponse):
                raise IncompleteCommitError(message)
        raise self.Error(message)

    def timedout(self):
        """ True if a non-negative max lifetime has elapsed since creation. """
        return 0 <= self._max_connection_lifetime <= perf_counter(
        ) - self._creation_timestamp

    def fetch_all(self):
        """ Fetch all outstanding messages.

        :return: 2-tuple of number of detail messages and number of summary
                 messages fetched
        """
        detail_count = summary_count = 0
        while self.responses:
            response = self.responses[0]
            while not response.complete:
                detail_delta, summary_delta = self.fetch_message()
                detail_count += detail_delta
                summary_count += summary_delta
        return detail_count, summary_count

    def close(self):
        """ Close the connection.

        A GOODBYE is sent on a best-effort basis unless the connection is
        defunct. The socket is always closed and the connection marked
        closed, even if queuing or sending GOODBYE fails.
        """
        if self._closed:
            return
        try:
            if not self._defunct:
                log.debug("[#%04X] C: GOODBYE", self.local_port)
                self._append(b"\x02", ())
                try:
                    self._send_all()
                except Exception:
                    # Best effort: the peer may already be gone. Narrower
                    # than a bare except so KeyboardInterrupt/SystemExit
                    # still propagate (after the socket is closed below).
                    pass
        finally:
            log.debug("[#%04X] C: <CLOSE>", self.local_port)
            try:
                self.socket.close()
            except IOError:
                pass
            finally:
                self._closed = True

    def closed(self):
        return self._closed

    def defunct(self):
        return self._defunct