def test_address_resolve_with_custom_resolver_none():
    """Resolving with ``resolver=None`` returns a one-element list holding
    the literal IPv4 address."""
    # python -m pytest tests/unit/test_addressing.py -s -k test_address_resolve_with_custom_resolver_none
    resolved = Address(("127.0.0.1", 7687)).resolve(resolver=None)
    # The result is a plain list of addresses, not an Address itself.
    assert not isinstance(resolved, Address)
    assert isinstance(resolved, list)
    assert len(resolved) == 1
    (only_address,) = resolved
    assert only_address == IPv4Address(("127.0.0.1", 7687))
async def _connect(cls, address, security, loop):
    """ Attempt to establish a TCP connection to the address provided.

    :param address: resolved :class:`Address` to dial
    :param security: ``True`` for ``Security.default()``, a
        :class:`Security` instance for custom settings, or a falsy value
        for an unencrypted connection
    :param loop: event loop to use; ``get_event_loop()`` if ``None``
    :return: a 3-tuple of reader, writer and security settings for the new
        connection (security is ``None`` when unencrypted)
    :raise TypeError: if ``security`` is truthy but neither ``True`` nor a
        :class:`Security` instance
    :raise BoltSecurityError: if a secure connection could not be established
    :raise BoltConnectionError: if a connection could not be established
    """
    assert isinstance(address, Address)
    if loop is None:
        loop = get_event_loop()
    connection_args = {
        "host": address.host,
        "port": address.port,
        "family": address.family,
        # TODO: other args
    }
    if security is True:
        security = Security.default()
    if isinstance(security, Security):
        ssl_context = security.to_ssl_context()
        connection_args["ssl"] = ssl_context
        connection_args["server_hostname"] = address.host
    elif security:
        raise TypeError(
            "Unsupported security configuration {!r}".format(security))
    else:
        security = None
    log.debug("[#0000] C: <DIAL> %s", address)
    try:
        reader = BoltStreamReader(loop=loop)
        protocol = StreamReaderProtocol(reader, loop=loop)
        transport, _ = await loop.create_connection(
            lambda: protocol, **connection_args)
        writer = BoltStreamWriter(transport, protocol, reader, loop)
    except SSLError as err:
        # errno may be None (e.g. an SSLError built from a message only);
        # strerror(None) would raise TypeError and hide the real failure.
        log.debug("[#%04X] S: <REJECT> %s (%s %s)", 0, address, err.errno,
                  strerror(err.errno) if err.errno is not None else err)
        raise BoltSecurityError("Failed to establish a secure connection",
                                address) from err
    except OSError as err:
        # create_connection raises a bare OSError (errno None) when several
        # candidate sockets fail ("Multiple exceptions: ...").
        log.debug("[#%04X] S: <REJECT> %s (%s %s)", 0, address, err.errno,
                  strerror(err.errno) if err.errno is not None else err)
        raise BoltConnectionError("Failed to establish a connection",
                                  address) from err
    else:
        local_address = Address(transport.get_extra_info("sockname"))
        remote_address = Address(transport.get_extra_info("peername"))
        log.debug("[#%04X] S: <ACCEPT> %s -> %s",
                  local_address.port_number, local_address, remote_address)
        return reader, writer, security
def test_address_resolve_with_custom_resolver():
    """Entries returned by a custom resolver are themselves resolved, so a
    host name such as "localhost" may expand into several addresses."""
    # python -m pytest tests/unit/test_addressing.py -s -k test_address_resolve_with_custom_resolver
    def custom_resolver(_unresolved):
        return [("127.0.0.1", 7687), ("localhost", 1234)]

    resolved = Address(("127.0.0.1", 7687)).resolve(resolver=custom_resolver)
    assert not isinstance(resolved, Address)
    assert isinstance(resolved, list)
    # NOTE(review): the expected order assumes "localhost" resolves to ::1
    # before 127.0.0.1 on the test host; this is environment dependent.
    assert len(resolved) == 3
    first, second, third = resolved
    assert first == IPv4Address(("127.0.0.1", 7687))
    assert second == IPv6Address(("::1", 1234, 0, 0))
    assert third == IPv4Address(("127.0.0.1", 1234))
def test_address_resolve_with_custom_resolver():
    """With ``family=AF_INET`` only the IPv4 results of the custom resolver
    are returned."""
    # python -m pytest tests/unit/test_addressing.py -s -k test_address_resolve_with_custom_resolver
    # NOTE(review): a test with this same name exists elsewhere in the test
    # suite; if both live in one module, the later definition silently
    # shadows the earlier one. Consider renaming.
    custom_resolver = lambda a: [("127.0.0.1", 7687), ("localhost", 1234)]
    address = Address(("127.0.0.1", 7687))
    resolved = address.resolve(family=AF_INET, resolver=custom_resolver)
    assert not isinstance(resolved, Address)
    assert isinstance(resolved, list)
    # IPv4 only. Asserting the length directly reports the actual list on
    # failure, unlike a bare ``assert False`` in an else-branch.
    assert len(resolved) == 2, resolved
    assert resolved[0] == IPv4Address(('127.0.0.1', 7687))
    assert resolved[1] == IPv4Address(('127.0.0.1', 1234))
async def _connect(cls, address, loop, config):
    """ Attempt to establish a TCP connection to the address provided.

    :param address: resolved :class:`Address` to dial
    :param loop: event loop on which to open the connection (must not be
        ``None``)
    :param config: :class:`PoolConfig`; TLS is used when
        ``config.get_ssl_context()`` returns a truthy context
    :return: a 2-tuple of reader and writer for the new connection
    :raise BoltSecurityError: if a secure connection could not be established
    :raise BoltConnectionError: if a connection could not be established
    """
    assert isinstance(address, Address)
    assert loop is not None
    assert isinstance(config, PoolConfig)
    connection_args = {
        "host": address.host,
        "port": address.port,
        "family": address.family,
        # TODO: other args
    }
    ssl_context = config.get_ssl_context()
    if ssl_context:
        connection_args["ssl"] = ssl_context
        connection_args["server_hostname"] = address.host
    log.debug("[#0000] C: <DIAL> %s", address)
    try:
        reader = BoltStreamReader(loop=loop)
        protocol = StreamReaderProtocol(reader, loop=loop)
        transport, _ = await loop.create_connection(
            lambda: protocol, **connection_args)
        writer = BoltStreamWriter(transport, protocol, reader, loop)
    except SSLError as err:
        # errno may be None (e.g. an SSLError built from a message only);
        # strerror(None) would raise TypeError and hide the real failure.
        log.debug("[#%04X] S: <REJECT> %s (%s %s)", 0, address, err.errno,
                  strerror(err.errno) if err.errno is not None else err)
        raise BoltSecurityError("Failed to establish a secure connection",
                                address) from err
    except OSError as err:
        # create_connection raises a bare OSError (errno None) when several
        # candidate sockets fail ("Multiple exceptions: ...").
        log.debug("[#%04X] S: <REJECT> %s (%s %s)", 0, address, err.errno,
                  strerror(err.errno) if err.errno is not None else err)
        raise BoltConnectionError("Failed to establish a connection",
                                  address) from err
    else:
        local_address = Address(transport.get_extra_info("sockname"))
        remote_address = Address(transport.get_extra_info("peername"))
        log.debug("[#%04X] S: <ACCEPT> %s -> %s",
                  local_address.port_number, local_address, remote_address)
        return reader, writer
async def _handshake(cls, reader, writer, protocol_version):
    """ Carry out a Bolt handshake, optionally requesting a specific
    protocol version.

    :param reader: stream reader for the new connection
    :param writer: stream writer for the new connection; its transport
        supplies the local and remote addresses used in logs and errors
    :param protocol_version: passed to ``cls.protocol_handlers`` to select
        the candidate handlers
    :return: the handler class registered for the agreed protocol version
    :raise ValueError: if no protocol handlers are available for the request
    :raise BoltConnectionLost: if an I/O error occurs on the underlying socket connection
    :raise BoltHandshakeError: if handshake completes without a successful negotiation
    """
    local_address = Address(writer.transport.get_extra_info("sockname"))
    remote_address = Address(writer.transport.get_extra_info("peername"))

    handlers = cls.protocol_handlers(protocol_version)
    if not handlers:
        # Format the message here. Passing the version as a separate
        # exception argument left "%r" unformatted in the error text.
        raise ValueError(
            "No protocol handlers available (requested Bolt %r)"
            % (protocol_version,))
    # Offer at most four versions, highest first, zero-padded to 16 bytes
    # after the magic preamble.
    offered_versions = sorted(handlers.keys(), reverse=True)[:4]
    request_data = MAGIC + b"".join(
        v.to_bytes() for v in offered_versions).ljust(16, b"\x00")
    log.debug("[#%04X] C: <HANDSHAKE> %r",
              local_address.port_number, request_data)
    writer.write(request_data)
    await writer.drain()
    response_data = await reader.readexactly(4)
    log.debug("[#%04X] S: <HANDSHAKE> %r",
              local_address.port_number, response_data)
    try:
        agreed_version = Version.from_bytes(response_data)
    except ValueError as err:
        writer.close()
        raise BoltHandshakeError(
            "Unexpected handshake response %r" % response_data,
            remote_address, request_data, response_data) from err
    try:
        subclass = handlers[agreed_version]
    except KeyError:
        log.debug("Unsupported Bolt protocol version %s", agreed_version)
        raise BoltHandshakeError("Unsupported Bolt protocol version",
                                 remote_address, request_data, response_data)
    else:
        return subclass
async def a_main(prog):
    """Run one Cypher query given on the command line and print the result.

    Prints the field names as a tab-separated header, then one
    tab-separated line of ``repr`` values per record. The connection is
    always closed afterwards.

    :param prog: program name shown in the argument parser's usage text
    """
    parser = ArgumentParser(prog=prog)
    parser.add_argument("cypher", help="Cypher query to execute")
    parser.add_argument("-a", "--auth", metavar="USER:PASSWORD", default="",
                        help="user name and password")
    parser.add_argument("-s", "--server-addr", metavar="HOST:PORT",
                        default=":7687", help="address of server")
    parser.add_argument("-v", "--verbose", action="store_true",
                        help="increase output verbosity")
    args = parser.parse_args()
    if args.verbose:
        watch("neo4j")

    server_address = Address.parse(args.server_addr)
    user_name, _, secret = args.auth.partition(":")
    # Prompt for the password when none was given after the colon.
    secret = secret or getpass()
    credentials = (user_name or "neo4j", secret)

    bolt = await Bolt.open(server_address, auth=credentials)
    try:
        result = await bolt.run(args.cypher)
        header = await result.fields()
        print("\t".join(header))
        async for record in result:
            print("\t".join(repr(value) for value in record))
    finally:
        await bolt.close()
def __new__(cls, uri, **config):
    """Create a driver instance for *uri* backed by a routing connection pool.

    The routing table is fetched once during construction. If that fails,
    the pool is closed and the error re-raised.

    :param uri: connection URI; its netloc gives the initial address
        (``DEFAULT_PORT`` if no port) and it is also parsed for a routing
        context
    :param config: driver options; ``encrypted`` defaults to ``False`` when
        absent or ``None``
    """
    cls._check_uri(uri)
    instance = object.__new__(cls)
    netloc = urlparse(uri).netloc
    initial_address = Address.parse(netloc, default_port=DEFAULT_PORT)
    instance.initial_address = initial_address

    # Encryption is off unless explicitly requested.
    if config.get("encrypted") is None:
        config["encrypted"] = False
    instance._ssl_context = make_ssl_context(**config)
    instance.encrypted = instance._ssl_context is not None
    routing_context = cls.parse_routing_context(uri)

    def connector(address, **kwargs):
        # Per-call keyword arguments override the driver-level config.
        return Connection.open(address, **dict(config, **kwargs))

    pool = RoutingConnectionPool(connector, initial_address, routing_context,
                                 initial_address, **config)
    try:
        pool.update_routing_table()
    except Exception:
        # Don't leave an open pool behind if the initial fetch fails.
        pool.close()
        raise
    instance._pool = pool
    instance._max_retry_time = config.get(
        "max_retry_time", default_config["max_retry_time"])
    return instance
def __new__(cls, uri, **config):
    """Create a driver instance for *uri* backed by a direct connection pool.

    One connection is acquired and released during construction
    (presumably to fail fast on an unreachable server). If that fails, the
    pool is closed and the error re-raised.

    :param uri: connection URI; its netloc gives the server address
        (``DEFAULT_PORT`` if no port)
    :param config: driver options; ``encrypted`` defaults to ``False`` when
        absent or ``None``
    """
    cls._check_uri(uri)
    instance = object.__new__(cls)
    # We keep the address containing the host name or IP address exactly
    # as-is from the original URI. This means that every new connection
    # will carry out DNS resolution, leading to the possibility that
    # the connection pool may contain multiple IP address keys, one for
    # an old address and one for a new address.
    parsed = urlparse(uri)
    instance.address = Address.parse(parsed.netloc, default_port=DEFAULT_PORT)
    if config.get("encrypted") is None:
        config["encrypted"] = False
    instance._ssl_context = make_ssl_context(**config)
    instance.encrypted = instance._ssl_context is not None

    def connector(address, **kwargs):
        # Per-call keyword arguments override the driver-level config.
        return Connection.open(address, **dict(config, **kwargs))

    pool = ConnectionPool(connector, instance.address, **config)
    try:
        pool.release(pool.acquire())
    except Exception:
        # Previously the pool was leaked here; close it before re-raising,
        # as the routing driver does when its initial fetch fails.
        pool.close()
        raise
    instance._pool = pool
    instance._max_retry_time = config.get("max_retry_time", default_config["max_retry_time"])
    return instance
def parse_routing_info(cls, records):
    """ Parse the records returned from a getServers call and return a new
    RoutingTable instance.

    :param records: sequence that must contain exactly one record, indexed
        by ``"servers"`` and ``"ttl"``; each server is indexed by ``"role"``
        and ``"addresses"``. Roles other than ROUTE, READ and WRITE are
        ignored.
    :raise RoutingProtocolError: if there is not exactly one record, or the
        record cannot be parsed (missing key or wrong type)
    """
    if len(records) != 1:
        raise RoutingProtocolError("Expected exactly one record")
    record = records[0]
    routers, readers, writers = [], [], []
    try:
        for server in record["servers"]:
            role = server["role"]
            parsed = [Address.parse(raw, default_port=DEFAULT_PORT)
                      for raw in server["addresses"]]
            if role == "ROUTE":
                routers.extend(parsed)
            elif role == "READ":
                readers.extend(parsed)
            elif role == "WRITE":
                writers.extend(parsed)
        ttl = record["ttl"]
    except (KeyError, TypeError):
        raise RoutingProtocolError("Cannot parse routing info")
    return cls(routers, readers, writers, ttl)
def parse_routing_info(cls, *, database, servers, ttl):
    """ Parse the server list returned from the procedure call and return a
    new RoutingTable instance.

    :param database: passed through to the table unchanged
    :param servers: iterable whose entries are indexed by ``"role"`` and
        ``"addresses"``; roles other than ROUTE, READ and WRITE are ignored
    :param ttl: passed through to the table unchanged
    :raise ValueError: if *servers* or one of its entries is missing a key
        or has the wrong type
    """
    routers, readers, writers = [], [], []
    try:
        for server in servers:
            role = server["role"]
            parsed = [Address.parse(raw, default_port=7687)
                      for raw in server["addresses"]]
            if role == "ROUTE":
                routers.extend(parsed)
            elif role == "READ":
                readers.extend(parsed)
            elif role == "WRITE":
                writers.extend(parsed)
    except (KeyError, TypeError):
        raise ValueError("Cannot parse routing info")
    return cls(database=database, routers=routers, readers=readers,
               writers=writers, ttl=ttl)
def connect(address, *, timeout, custom_resolver, ssl_context, keep_alive):
    """ Connect and perform a handshake and return a valid Connection object,
    assuming a protocol version can be agreed.

    Every address produced by ``Address(address).resolve(resolver=
    custom_resolver)`` is tried in turn. The first successful handshake
    wins.

    :param address: unresolved address; ``address[0]`` is the host passed
        to ``_secure``
    :raise ServiceUnavailable: if resolution yields no addresses at all
    :raise Exception: the error from the last attempt, if every attempt fails
    """
    last_error = None
    # Establish a connection to the host and port specified
    # Catches refused connections see:
    # https://docs.python.org/2/library/errno.html
    log.debug("[#0000] C: <RESOLVE> %s", address)
    for resolved_address in Address(address).resolve(resolver=custom_resolver):
        s = None
        try:
            host = address[0]
            s = _connect(resolved_address, timeout, keep_alive)
            s = _secure(s, host, ssl_context)
            return _handshake(s, address)
        except Exception as error:
            # Close the partially set-up socket and try the next address.
            if s:
                s.close()
            last_error = error
    if last_error is None:
        # Wrap address in a 1-tuple: a bare tuple on the right of "%" would
        # be spread into format args and raise TypeError instead.
        raise ServiceUnavailable(
            "Failed to resolve addresses for %s" % (address,))
    else:
        raise last_error
def parse_target(cls, target):
    """ Parse a target string to produce an address.

    An empty or missing *target* falls back to ``cls.default_target``.
    ``cls.default_host`` and ``cls.default_port`` are passed to
    ``Address.parse`` as its defaults.
    """
    return Address.parse(target or cls.default_target,
                         default_host=cls.default_host,
                         default_port=cls.default_port)
def test_address_initialization(test_input, expected):
    """Address exposes the expected family, host, port, str() and repr()."""
    # python -m pytest tests/unit/test_addressing.py -s -k test_address_initialization
    address = Address(test_input)
    for attribute in ("family", "host", "port"):
        assert getattr(address, attribute) == expected[attribute]
    assert str(address) == expected["str"]
    assert repr(address) == expected["repr"]
def parse_targets(cls, *targets):
    """ Parse a sequence of target strings to produce an address list.

    The targets are joined with single spaces. If the result is empty,
    ``cls.default_targets`` is parsed instead. ``cls.default_host`` and
    ``cls.default_port`` are passed to ``Address.parse_list`` as defaults.
    """
    joined = " ".join(targets)
    return Address.parse_list(joined or cls.default_targets,
                              default_host=cls.default_host,
                              default_port=cls.default_port)
def __init__(self, loop, opener, config, address):
    """Set up an empty connection pool for a single address.

    :param loop: event loop to use; ``get_event_loop()`` if ``None``
    :param opener: stored as ``_opener``
    :param config: provides ``max_size`` and ``max_age``
    :param address: converted to an :class:`Address`
    """
    self._loop = get_event_loop() if loop is None else loop
    self._opener = opener
    self._address = Address(address)
    self._max_size = config.max_size
    self._max_age = config.max_age
    # Presumably connections checked out vs. idle ones, per the names.
    self._in_use_list = deque()
    self._free_list = deque()
    self._waiting_list = WaitingList(loop=self._loop)
def test_serverinfo_initialization():
    """A new ServerInfo keeps the exact address and version objects given,
    and has no agent or connection id yet."""
    # python -m pytest tests/unit/test_api.py -s -k test_serverinfo_initialization

    from neo4j.addressing import Address

    # Use a plain host name: "bolt://localhost" is a URI, not a host.
    address = Address(("localhost", 7687))
    version = neo4j.api.Version(3, 0)

    server_info = neo4j.api.ServerInfo(address, version)
    assert server_info.address is address
    assert server_info.protocol_version is version
    assert server_info.agent is None
    assert server_info.connection_id is None
def test_serverinfo_with_metadata(test_input, expected_agent, expected_version_info):
    """ServerInfo.update() populates the agent and the parsed version info."""
    # python -m pytest tests/unit/test_api.py -s -k test_serverinfo_with_metadata
    from neo4j.addressing import Address

    # Use a plain host name: "bolt://localhost" is a URI, not a host.
    address = Address(("localhost", 7687))
    version = neo4j.api.Version(3, 0)

    server_info = neo4j.api.ServerInfo(address, version)
    server_info.update(test_input)

    assert server_info.agent == expected_agent
    assert server_info.version_info() == expected_version_info
def __init__(self, unresolved_address, sock, max_connection_lifetime, *,
             auth=None, user_agent=None, routing_context=None):
    """Wrap a connected socket as a Bolt connection.

    :param unresolved_address: stored unchanged as ``unresolved_address``
    :param sock: connected socket; its peer name becomes the address of
        ``server_info``
    :param max_connection_lifetime: stored as ``_max_connection_lifetime``
        (units not visible here, presumably seconds; TODO confirm)
    :param auth: falsy for no auth; a 2- or 3-tuple passed to
        ``Auth("basic", ...)``; otherwise any object whose ``vars()`` form
        the auth dict
    :param user_agent: user agent string; ``get_user_agent()`` if falsy
    :param routing_context: stored unchanged as ``routing_context``
    :raise AuthError: if ``auth`` cannot be turned into a dict, or if its
        ``credentials`` entry is ``None``
    """
    self.unresolved_address = unresolved_address
    self.socket = sock
    self.server_info = ServerInfo(Address(sock.getpeername()),
                                  self.PROTOCOL_VERSION)
    # The packer writes into the outbox; the inbox wraps the socket and
    # feeds the unpacker.
    self.outbox = Outbox()
    self.inbox = Inbox(self.socket, on_error=self._set_defunct)
    self.packer = Packer(self.outbox)
    self.unpacker = Unpacker(self.inbox)
    self.responses = deque()
    self._max_connection_lifetime = max_connection_lifetime
    self._creation_timestamp = perf_counter()
    self.supports_multiple_results = False
    self.supports_multiple_databases = False
    self._is_reset = True
    self.routing_context = routing_context

    # Determine the user agent
    if user_agent:
        self.user_agent = user_agent
    else:
        self.user_agent = get_user_agent()

    # Determine auth details
    if not auth:
        self.auth_dict = {}
    elif isinstance(auth, tuple) and 2 <= len(auth) <= 3:
        from neo4j import Auth
        self.auth_dict = vars(Auth("basic", *auth))
    else:
        try:
            # vars() raises TypeError for objects without a __dict__.
            self.auth_dict = vars(auth)
        except (KeyError, TypeError):
            # Wrap auth in a 1-tuple: a tuple auth (e.g. of length 4) would
            # otherwise be spread into the format args and raise TypeError.
            raise AuthError("Cannot determine auth details from %r" % (auth,))

    # Check for missing password
    try:
        credentials = self.auth_dict["credentials"]
    except KeyError:
        pass
    else:
        if credentials is None:
            raise AuthError("Password cannot be None")
async def open(cls, address, *, auth=None, security=False,
               protocol_version=None, loop=None):
    """ Open a socket connection and perform protocol version negotiation,
    in order to construct and return a Bolt client instance for a supported
    Bolt protocol version.

    :param address: tuples of host and port, such as ("127.0.0.1", 7687)
    :param auth: passed to the new instance's ``__ainit__``
    :param security: passed to ``_connect``; the resulting settings are
        stored on the instance as ``security``
    :param protocol_version: passed to ``_handshake``
    :param loop: passed to ``_connect``
    :return: instance of a Bolt subclass
    :raise BoltConnectionError: if a connection could not be established
    :raise BoltConnectionLost: if an I/O error occurs on the underlying
        socket connection
    :raise BoltHandshakeError: if handshake completes without a successful
        negotiation
    :raise TypeError: if any of the arguments provided are passed as
        incompatible types
    :raise ValueError: if any of the arguments provided are passed with
        unsupported values
    """
    # Connect
    address = Address(address)
    reader, writer, security = await cls._connect(address, security, loop)
    try:
        # Handshake
        subclass = await cls._handshake(reader, writer, protocol_version)
        # Instantiation
        inst = subclass(reader, writer)
        inst.security = security
        assert hasattr(inst, "__ainit__")
        await inst.__ainit__(auth)
        return inst
    except BaseException:
        # Release the connection on any failure, not only BoltError (e.g.
        # ValueError from _handshake, or cancellation).
        try:
            # TLS transports cannot half-close: write_eof() raises
            # NotImplementedError there, which would mask the original
            # error and skip close().
            if writer.can_write_eof():
                writer.write_eof()
        except OSError:
            pass  # best effort; the writer is closed below regardless
        writer.close()
        raise
async def open(cls, address, *, auth=None, loop=None, **config):
    """ Open a socket connection and perform protocol version negotiation,
    in order to construct and return a Bolt client instance for a supported
    Bolt protocol version.

    :param address: tuples of host and port, such as ("127.0.0.1", 7687)
    :param auth: passed to the new instance's ``__ainit__``
    :param loop: event loop to use; ``get_event_loop()`` if ``None``
    :param config: options consumed into a :class:`PoolConfig`; its
        ``protocol_version`` is requested in the handshake and
        ``bool(secure)`` is stored on the instance as ``secure``
    :return: instance of a Bolt subclass
    :raise BoltConnectionError: if a connection could not be established
    :raise BoltConnectionLost: if an I/O error occurs on the underlying
        socket connection
    :raise BoltHandshakeError: if handshake completes without a successful
        negotiation
    :raise TypeError: if any of the arguments provided are passed as
        incompatible types
    :raise ValueError: if any of the arguments provided are passed with
        unsupported values
    """
    # Args
    address = Address(address)
    if loop is None:
        loop = get_event_loop()
    config = PoolConfig.consume(config)

    # Connect
    reader, writer = await cls._connect(address, loop, config)

    try:
        # Handshake
        subclass = await cls._handshake(reader, writer,
                                        config.protocol_version)

        # Instantiation
        obj = subclass(reader, writer)
        obj.secure = bool(config.secure)
        assert hasattr(obj, "__ainit__")
        await obj.__ainit__(auth)
        return obj
    except BaseException:
        # Release the connection on any failure, not only BoltError (e.g.
        # ValueError from _handshake, or cancellation).
        try:
            # TLS transports cannot half-close: write_eof() raises
            # NotImplementedError there, which would mask the original
            # error and skip close().
            if writer.can_write_eof():
                writer.write_eof()
        except OSError:
            pass  # best effort; the writer is closed below regardless
        writer.close()
        raise
def __init__(self, unresolved_address, sock, *, auth=None,
             protocol_version=None, **config):
    """Wrap a connected socket as a Bolt connection.

    :param unresolved_address: stored unchanged as ``unresolved_address``
    :param sock: connected socket; its peer name becomes the address of
        ``server``
    :param auth: falsy for no auth; a 2- or 3-tuple passed to
        ``Auth("basic", ...)``; otherwise any object whose ``vars()`` form
        the auth dict
    :param protocol_version: stored as ``protocol_version`` and passed to
        ``ServerInfo``
    :param config: remaining options, consumed into a :class:`PoolConfig`;
        ``max_age`` and ``user_agent`` are read here
    :raise AuthError: if ``auth`` cannot be turned into a dict, or if its
        ``credentials`` entry is ``None``
    """
    self.config = PoolConfig.consume(config)
    self.protocol_version = protocol_version
    self.unresolved_address = unresolved_address
    self.socket = sock
    self.server = ServerInfo(Address(sock.getpeername()), protocol_version)
    # The packer writes into the outbox; the inbox reads through a
    # BufferedSocket (32768 presumably its buffer size) and feeds the
    # unpacker.
    self.outbox = Outbox()
    self.inbox = Inbox(BufferedSocket(self.socket, 32768),
                       on_error=self._set_defunct)
    self.packer = Packer(self.outbox)
    self.unpacker = Unpacker(self.inbox)
    self.responses = deque()
    self._max_connection_lifetime = self.config.max_age
    self._creation_timestamp = perf_counter()

    # Determine the user agent
    user_agent = self.config.user_agent
    if user_agent:
        self.user_agent = user_agent
    else:
        self.user_agent = get_user_agent()

    # Determine auth details
    if not auth:
        self.auth_dict = {}
    elif isinstance(auth, tuple) and 2 <= len(auth) <= 3:
        from neo4j import Auth
        self.auth_dict = vars(Auth("basic", *auth))
    else:
        try:
            # vars() raises TypeError for objects without a __dict__.
            self.auth_dict = vars(auth)
        except (KeyError, TypeError):
            # Wrap auth in a 1-tuple: a tuple auth (e.g. of length 4) would
            # otherwise be spread into the format args and raise TypeError.
            raise AuthError("Cannot determine auth details from %r" % (auth,))

    # Check for missing password
    try:
        credentials = self.auth_dict["credentials"]
    except KeyError:
        pass
    else:
        if credentials is None:
            raise AuthError("Password cannot be None")
def __init__(self, protocol_version, unresolved_address, sock, **config):
    """Wrap a connected socket as a Bolt connection.

    :param protocol_version: stored as ``protocol_version`` and passed to
        ``ServerInfo``
    :param unresolved_address: stored unchanged as ``unresolved_address``
    :param sock: connected socket; its peer name becomes the address of
        ``server``
    :param config: options read here: ``max_connection_lifetime``
        (default ``DEFAULT_MAX_CONNECTION_LIFETIME``), ``user_agent``,
        ``auth`` and ``der_encoded_server_certificate``
    :raise AuthError: if ``auth`` cannot be turned into a dict, or if its
        ``credentials`` entry is ``None``
    """
    self.protocol_version = protocol_version
    self.unresolved_address = unresolved_address
    self.socket = sock
    self.server = ServerInfo(Address(sock.getpeername()), protocol_version)
    # The packer writes into the outbox; the inbox reads through a
    # BufferedSocket (32768 presumably its buffer size) and feeds the
    # unpacker.
    self.outbox = Outbox()
    self.inbox = Inbox(BufferedSocket(self.socket, 32768),
                       on_error=self._set_defunct)
    self.packer = Packer(self.outbox)
    self.unpacker = Unpacker(self.inbox)
    self.responses = deque()
    self._max_connection_lifetime = config.get(
        "max_connection_lifetime", DEFAULT_MAX_CONNECTION_LIFETIME)
    self._creation_timestamp = perf_counter()

    # Determine the user agent
    user_agent = config.get("user_agent")
    if user_agent:
        self.user_agent = user_agent
    else:
        self.user_agent = get_user_agent()

    # Determine auth details
    auth = config.get("auth")
    if not auth:
        self.auth_dict = {}
    elif isinstance(auth, tuple) and 2 <= len(auth) <= 3:
        self.auth_dict = vars(AuthToken("basic", *auth))
    else:
        try:
            # vars() raises TypeError for objects without a __dict__.
            self.auth_dict = vars(auth)
        except (KeyError, TypeError):
            # Wrap auth in a 1-tuple: a tuple auth (e.g. of length 4) would
            # otherwise be spread into the format args and raise TypeError.
            raise AuthError("Cannot determine auth details from %r" % (auth,))

    # Check for missing password
    try:
        credentials = self.auth_dict["credentials"]
    except KeyError:
        pass
    else:
        if credentials is None:
            raise AuthError("Password cannot be None")

    # Pick up the server certificate, if any
    self.der_encoded_server_certificate = config.get(
        "der_encoded_server_certificate")
def test_address_parse_with_invalid_input(test_input, expected):
    """Address.parse rejects malformed input with the expected exception."""
    # python -m pytest tests/unit/test_addressing.py -s -k test_address_parse_with_invalid_input
    with pytest.raises(expected):
        Address.parse(test_input)
def test_address_parse_with_ipv4(test_input, expected):
    """Address.parse turns each IPv4 input into the expected address."""
    # python -m pytest tests/unit/test_addressing.py -s -k test_address_parse_with_ipv4
    assert Address.parse(test_input) == expected
def local_address(self):
    """Return the local end of the transport as an :class:`Address`,
    built from the transport's ``"sockname"`` extra info."""
    sockname = self.__transport.get_extra_info("sockname")
    return Address(sockname)
def test_address_should_parse_ipv6(test_input, expected):
    """Address.parse turns each IPv6 input into the expected address."""
    # python -m pytest tests/unit/test_addressing.py -s -k test_address_should_parse_ipv6
    assert Address.parse(test_input) == expected
def test_address_parse_list_with_invalid_input(test_input, expected):
    """Address.parse_list raises TypeError for invalid arguments."""
    # python -m pytest tests/unit/test_addressing.py -s -k test_address_parse_list_with_invalid_input
    # NOTE(review): the parametrized ``expected`` is not used; TypeError is
    # hard-coded. Confirm this is intended.
    with pytest.raises(TypeError):
        Address.parse_list(*test_input)
def remote_address(self):
    """Return the remote end of the transport as an :class:`Address`,
    built from the transport's ``"peername"`` extra info."""
    peername = self.__transport.get_extra_info("peername")
    return Address(peername)
def test_address_parse_list(test_input, expected):
    """Address.parse_list yields the expected number of addresses."""
    # python -m pytest tests/unit/test_addressing.py -s -k test_address_parse_list
    parsed_addresses = Address.parse_list(*test_input)
    assert len(parsed_addresses) == expected