def test_parsing_empty_string():
    """An empty string yields no addresses at all."""
    parsed = AddressList.parse("")
    assert parsed == []
def test_parsing_empty_host_and_port_with_both_defaults():
    """A bare ':' takes both host and port from the supplied defaults."""
    expected = [("x", 80)]
    parsed = AddressList.parse(":", default_host="x", default_port=80)
    assert parsed == expected
def test_illegal_type_in_parsing():
    """Parsing an arbitrary non-string object raises TypeError."""
    not_a_string = object()
    with raises(TypeError):
        AddressList.parse(not_a_string)
def test_parsing_port_only_with_default_host():
    """A port with no host takes its host from ``default_host``."""
    parsed = AddressList.parse(":http", default_host="x")
    expected = [("x", "http")]
    assert parsed == expected
def test_parsing_empty_host_and_port():
    """Without defaults, a bare ':' becomes localhost with port 0."""
    parsed = AddressList.parse(":")
    assert parsed == [("localhost", 0)]
def test_parsing_host_only_with_default_port():
    """A host with no port takes its port from ``default_port``."""
    expected = [("localhost", 80)]
    parsed = AddressList.parse("localhost", default_port=80)
    assert parsed == expected
def test_parsing_port_only():
    """A port with no host and no default host uses localhost."""
    parsed = AddressList.parse(":http")
    assert parsed == [("localhost", "http")]
def test_parsing_host_and_port():
    """Host and named port are both taken verbatim from the string."""
    expected = [("localhost", "http")]
    parsed = AddressList.parse("localhost:http")
    assert parsed == expected
def test_parsing_host_only():
    """A host with no port and no default port gets port 0."""
    parsed = AddressList.parse("localhost")
    assert parsed == [("localhost", 0)]
def test_parsing_multiple_addresses():
    """Whitespace-separated addresses are parsed in order.

    IPv4 entries become 2-tuples; bracketed IPv6 entries become 4-tuples,
    the shape used for AF_INET6 socket addresses.
    """
    expected = [
        ("127.0.0.1", "80"),
        ("::1", "80", 0, 0),
    ]
    parsed = AddressList.parse("127.0.0.1:80 [::1]:80")
    assert parsed == expected
def test_parsing_ipv6_address_list():
    """A bracketed IPv6 host parses to a 4-tuple with zeroed trailing fields."""
    parsed = AddressList.parse("[::1]:80")
    assert parsed == [("::1", "80", 0, 0)]
def test_parsing_ipv4_address_list():
    """An IPv4 host and numeric port parse to a 2-tuple of strings."""
    expected = [("127.0.0.1", "80")]
    parsed = AddressList.parse("127.0.0.1:80")
    assert parsed == expected
class Connection:
    """ The Connection wraps a socket through which protocol messages are
    sent and received. The socket is owned by this Connection instance,
    which closes it in close() (also called on leaving a ``with`` block).
    """

    # Maximum size of a single data chunk.
    max_chunk_size = 65535

    # The default address list to use if no addresses are specified.
    default_address_list = AddressList.parse(":7687 :17601 :17687")

    # Set to True on the instance once close() has been called.
    closed = False

    @classmethod
    def default_user_agent(cls):
        """ Return the default user agent string for a Connection, in the
        form "<package>/<version>" taken from grolt.meta.
        """
        from grolt.meta import package, version
        return "{}/{}".format(package, version)

    @classmethod
    def fix_bolt_versions(cls, bolt_versions):
        """ Using the requested Bolt versions, and falling back on the full
        list available, generate a tuple of exactly four Bolt protocol
        versions for use in version negotiation.

        Args:
            bolt_versions: Iterable of requested version numbers, or a
                falsy value to request every version that has an entry
                in CLIENT (highest first).

        Returns:
            A 4-tuple of ints: the requested versions, truncated to four
            or padded with zeroes.

        Raises:
            ValueError: if any requested version is below zero or above
                MAX_BOLT_VERSION.
        """
        # Establish which protocol versions we want to attempt to use
        if not bolt_versions:
            bolt_versions = sorted(
                [v for v, x in enumerate(CLIENT) if x is not None],
                reverse=True)
        # Materialise once: both the range check and the padding below
        # iterate, which would silently exhaust a one-shot iterator.
        bolt_versions = list(bolt_versions)
        # Raise an error if we're including any non-supported versions
        if any(v < 0 or v > MAX_BOLT_VERSION for v in bolt_versions):
            # Wrap in a 1-tuple so %-formatting never treats the value as
            # a sequence of separate format arguments.
            raise ValueError("This client does not support all "
                             "Bolt versions in %r" % (bolt_versions,))
        # Ensure we send exactly 4 versions, padding with zeroes if necessary
        return tuple(bolt_versions + [0, 0, 0, 0])[:4]

    @staticmethod
    def _recv_exactly(s, n):
        """ Read from socket `s` until `n` bytes have arrived or the peer
        closes the connection.

        A single recv() call may return fewer bytes than requested, so
        this loops. The result is shorter than `n` only if the peer
        closed the connection first.
        """
        buffer = bytearray()
        while len(buffer) < n:
            data = s.recv(n - len(buffer))
            if not data:
                break  # peer closed the connection
            buffer.extend(data)
        return bytes(buffer)

    @classmethod
    def _open_to(cls, address, auth, user_agent, bolt_versions):
        """ Attempt to open a connection to a Bolt server, given a single
        socket address.

        Connects, sends the handshake (the BOLT preamble followed by the
        four candidate versions) and reads the 4-byte version chosen by
        the server. `address` must be a 2-tuple (IPv4) or a 4-tuple
        (IPv6); any other length raises KeyError.

        Returns:
            A new Connection if the server picked one of `bolt_versions`,
            otherwise None. Whenever no Connection is returned, the socket
            is closed.

        Raises:
            OSError: from connecting or sending, or from the Connection
                constructor (e.g. if the peer goes away mid-handshake).
        """
        cx = None
        handshake_data = BOLT + b"".join(
            raw_pack(UINT_32, version) for version in bolt_versions)
        s = socket(family={2: AF_INET, 4: AF_INET6}[len(address)])
        try:
            s.connect(address)
            s.sendall(handshake_data)
            # Keep reading until the full 4-byte reply is in; a short read
            # is only ever a sign that the peer closed the connection.
            raw_bolt_version = cls._recv_exactly(s, 4)
            if len(raw_bolt_version) == 4:
                bolt_version, = raw_unpack(UINT_32, raw_bolt_version)
                if bolt_version > 0 and bolt_version in bolt_versions:
                    cx = cls(s, bolt_version, auth, user_agent)
                else:
                    log.error(
                        "Could not negotiate protocol version "
                        "(outcome=%d)", bolt_version)
            # Otherwise the peer closed the connection before replying in
            # full; cx stays None and the socket is closed below.
        finally:
            if cx is None:
                s.close()
        return cx

    @classmethod
    def open(cls, *addresses, auth, user_agent=None, bolt_versions=None,
             timeout=0):
        """ Open a connection to a Bolt server.

        It is here that we create a low-level socket connection and carry
        out version negotiation. Following this (and assuming success) a
        Connection instance will be returned. This Connection takes
        ownership of the underlying socket and is subsequently responsible
        for managing its lifecycle.

        Every address is tried in turn. If none succeeds and `timeout` has
        not yet elapsed, the whole list is tried again after a pause that
        starts at 0.1 seconds and doubles on every round.

        Args:
            addresses: Tuples of host and port, such as ("127.0.0.1", 7687).
                If none are given, `default_address_list` is used.
            auth: Expected to be a (user, password) pair; a value that
                cannot be unpacked as such falls back to placeholder
                credentials.
            user_agent: User agent string to send; None means
                default_user_agent().
            bolt_versions: Requested Bolt versions, as accepted by
                fix_bolt_versions().
            timeout: Seconds during which failed attempts are retried;
                0 (the default) means a single round of attempts.

        Returns:
            A connection to the Bolt server.

        Raises:
            OSError: if no connection could be opened, including the case
                where no protocol version could be negotiated.
            ValueError: if `bolt_versions` includes an unsupported version.
        """
        addresses = AddressList(addresses or cls.default_address_list)
        addresses.resolve()
        t0 = perf_counter()
        bolt_versions = cls.fix_bolt_versions(bolt_versions)
        log.debug("Trying to open connection to «%s»", addresses)
        errors = set()
        again = True
        wait = 0.1
        while again:
            for address in addresses:
                try:
                    cx = cls._open_to(address, auth, user_agent,
                                      bolt_versions)
                except OSError as e:
                    errors.add(" ".join(map(str, e.args)))
                else:
                    if cx:
                        return cx
            again = perf_counter() - t0 < (timeout or 0)
            if again:
                sleep(wait)
                wait *= 2
        log.error("Could not open connection to «%s» (%r)",
                  addresses, errors)
        raise OSError("Could not open connection")

    def __init__(self, s, bolt_version, auth, user_agent=None):
        """ Take ownership of the connected socket `s`, on which Bolt
        version `bolt_version` has already been negotiated.

        Initialises the session by sending HELLO (Bolt v3 and above) or
        INIT (earlier versions) and waiting for the reply. The "server"
        entry of the reply metadata is stored as `server_agent`.
        """
        self.socket = s
        self.address = AddressList([self.socket.getpeername()])
        self.bolt_version = bolt_version
        log.debug("Opened connection to «%s» using Bolt v%d",
                  self.address, self.bolt_version)
        self.requests = []
        self.responses = []
        try:
            user, password = auth
        except (TypeError, ValueError):
            # auth is not a (user, password) pair: use placeholders.
            user, password = "******", ""
        if user_agent is None:
            user_agent = self.default_user_agent()
        if bolt_version >= 3:
            args = {
                "scheme": "basic",
                "principal": user,
                "credentials": password,
                "user_agent": user_agent,
            }
            # The password is masked in the log output.
            log.debug("C: HELLO %r" % dict(args, credentials="..."))
            request = Structure(CLIENT[self.bolt_version]["HELLO"], args)
        else:
            auth_token = {
                "scheme": "basic",
                "principal": user,
                "credentials": password,
            }
            log.debug("C: INIT %r %r", user_agent,
                      dict(auth_token, credentials="..."))
            request = Structure(CLIENT[self.bolt_version]["INIT"],
                                user_agent, auth_token)
        self.requests.append(request)
        response = Response(self)
        self.responses.append(response)
        self.send_all()
        self.fetch_all()
        self.server_agent = response.metadata["server"]

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def close(self):
        """ Close the underlying socket; repeated calls are no-ops. """
        if not self.closed:
            log.debug("Closing connection to «%s»", self.address)
            self.socket.close()
            self.closed = True

    def reset(self):
        """ Enqueue a RESET message and immediately send it, together with
        any other pending requests.

        :return: the :class:`Response` that will receive the reply
        """
        log.debug("C: RESET")
        self.requests.append(Structure(CLIENT[self.bolt_version]["RESET"]))
        self.send_all()
        response = Response(self)
        self.responses.append(response)
        return response

    def run(self, cypher, parameters=None, metadata=None):
        """ Enqueue a RUN message for `cypher` with `parameters`.

        `metadata` is only supported from Bolt v3. The message is not
        sent until send_all() is called.

        :return: :class:`.QueryResponse` object
        :raise ProtocolError: if metadata is given on Bolt below v3
        """
        parameters = parameters or {}
        metadata = metadata or {}
        if self.bolt_version >= 3:
            log.debug("C: RUN %r %r %r", cypher, parameters, metadata)
            run = Structure(CLIENT[self.bolt_version]["RUN"],
                            cypher, parameters, metadata)
        elif metadata:
            raise ProtocolError("RUN metadata is not available in Bolt v%d" %
                                self.bolt_version)
        else:
            log.debug("C: RUN %r %r", cypher, parameters)
            run = Structure(CLIENT[self.bolt_version]["RUN"],
                            cypher, parameters)
        self.requests.append(run)
        response = QueryResponse(self)
        self.responses.append(response)
        return response

    def discard(self, n, qid):
        """ Enqueue a DISCARD message.

        Before Bolt v4 only DISCARD_ALL is available, so both `n` and
        `qid` must be negative there.

        :param n: number of records to discard (-1 means all)
        :param qid: the query for which to discard records (-1 means the
            query immediately preceding)
        :return: :class:`.QueryResponse` object
        :raise ProtocolError: if `n` or `qid` is non-negative on Bolt
            below v4
        """
        v = self.bolt_version
        if v >= 4:
            args = {"n": n}
            if qid >= 0:
                args["qid"] = qid
            log.debug("C: DISCARD %r", args)
            self.requests.append(Structure(CLIENT[v]["DISCARD"], args))
        elif n >= 0 or qid >= 0:
            raise ProtocolError("Reactive DISCARD is not available in "
                                "Bolt v%d" % v)
        else:
            log.debug("C: DISCARD_ALL")
            self.requests.append(Structure(CLIENT[v]["DISCARD_ALL"]))
        response = QueryResponse(self)
        self.responses.append(response)
        return response

    def pull(self, n, qid, records):
        """ Enqueue a PULL message.

        Before Bolt v4 only PULL_ALL is available, so both `n` and `qid`
        must be negative there.

        :param n: number of records to pull (-1 means all)
        :param qid: the query for which to pull records (-1 means the
            query immediately preceding)
        :param records: list-like container into which records may be
            appended
        :return: :class:`.QueryResponse` object
        :raise ProtocolError: if `n` or `qid` is non-negative on Bolt
            below v4
        """
        v = self.bolt_version
        if v >= 4:
            args = {"n": n}
            if qid >= 0:
                args["qid"] = qid
            log.debug("C: PULL %r", args)
            self.requests.append(Structure(CLIENT[v]["PULL"], args))
        elif n >= 0 or qid >= 0:
            raise ProtocolError("Reactive PULL is not available in "
                                "Bolt v%d" % v)
        else:
            log.debug("C: PULL_ALL")
            self.requests.append(Structure(CLIENT[v]["PULL_ALL"]))
        response = QueryResponse(self, records)
        self.responses.append(response)
        return response

    def begin(self, metadata=None):
        """ Enqueue a BEGIN message carrying `metadata` (Bolt v3+ only).

        :return: :class:`.QueryResponse` object
        :raise ProtocolError: on Bolt below v3
        """
        metadata = metadata or {}
        if self.bolt_version >= 3:
            log.debug("C: BEGIN %r", metadata)
            self.requests.append(
                Structure(CLIENT[self.bolt_version]["BEGIN"], metadata))
        else:
            raise ProtocolError("BEGIN is not available in Bolt v%d" %
                                self.bolt_version)
        response = QueryResponse(self)
        self.responses.append(response)
        return response

    def commit(self):
        """ Enqueue a COMMIT message (Bolt v3+ only).

        :return: :class:`.QueryResponse` object
        :raise ProtocolError: on Bolt below v3
        """
        if self.bolt_version >= 3:
            log.debug("C: COMMIT")
            self.requests.append(Structure(
                CLIENT[self.bolt_version]["COMMIT"]))
        else:
            raise ProtocolError("COMMIT is not available in Bolt v%d" %
                                self.bolt_version)
        response = QueryResponse(self)
        self.responses.append(response)
        return response

    def rollback(self):
        """ Enqueue a ROLLBACK message (Bolt v3+ only).

        :return: :class:`.QueryResponse` object
        :raise ProtocolError: on Bolt below v3
        """
        if self.bolt_version >= 3:
            log.debug("C: ROLLBACK")
            self.requests.append(
                Structure(CLIENT[self.bolt_version]["ROLLBACK"]))
        else:
            raise ProtocolError("ROLLBACK is not available in Bolt v%d" %
                                self.bolt_version)
        response = QueryResponse(self)
        self.responses.append(response)
        return response

    def send_all(self):
        """ Send all pending request messages to the server.

        Each packed message is split into chunks of at most
        `max_chunk_size` bytes. Every chunk is prefixed with its length as
        a UINT_16, and a zero-length marker ends the message. All queued
        messages go out in a single sendall() call.
        """
        if not self.requests:
            return
        data = []
        while self.requests:
            request = self.requests.pop(0)
            request_data = pack(request)
            for offset in range(0, len(request_data), self.max_chunk_size):
                end = offset + self.max_chunk_size
                chunk = request_data[offset:end]
                data.append(raw_pack(UINT_16, len(chunk)))
                data.append(chunk)
            data.append(raw_pack(UINT_16, 0))
        self.socket.sendall(b"".join(data))

    def _read(self, n):
        """ Read exactly `n` bytes from this connection's socket.

        Raises:
            ConnectionError: if the peer closes the connection before
                `n` bytes have arrived.
        """
        data = self._recv_exactly(self.socket, n)
        if len(data) < n:
            raise ConnectionError("Connection closed by peer")
        return data

    def fetch_one(self):
        """ Receive exactly one response message from the server and hand
        it to the oldest outstanding response, which is removed from the
        queue once it is complete.

        This method blocks until either a message arrives or the
        connection is terminated.

        Raises:
            ConnectionError: if the peer closes the connection before a
                complete message has been received.
        """
        # Receive chunks of data until chunk_size == 0. Reads go through
        # _read() because a single recv() may return only part of a chunk.
        data = []
        chunk_size = -1
        while chunk_size != 0:
            chunk_size, = raw_unpack(UINT_16, self._read(2))
            if chunk_size > 0:
                data.append(self._read(chunk_size))
        message = unpack(b"".join(data))

        # Handle message
        response = self.responses[0]
        response.on_message(message.tag, *message.fields)
        if response.complete:
            self.responses.pop(0)

    def fetch_summary(self):
        """ Fetch all messages up to and including the next summary
        message.
        """
        response = self.responses[0]
        while not response.complete and not self.closed:
            self.fetch_one()

    def fetch_all(self):
        """ Fetch all messages from all outstanding responses.
        """
        while self.responses and not self.closed:
            self.fetch_summary()
def convert(self, value, param, ctx):
    """ Parse `value` into an AddressList.

    This object's `default_host` and `default_port` are passed to
    AddressList.parse as the fallback host and port. `param` and `ctx`
    are accepted but not used.

    NOTE(review): the (value, param, ctx) signature matches click's
    ParamType.convert. If that is the base class, click may also pass
    values that are already converted; confirm AddressList.parse copes
    with those.
    """
    host, port = self.default_host, self.default_port
    return AddressList.parse(value, host, port)