예제 #1
0
def test_address_resolve_with_custom_resolver_none():
    # python -m pytest tests/unit/test_addressing.py -s -k test_address_resolve_with_custom_resolver_none
    # With no custom resolver, resolving a literal IPv4 address yields a plain
    # list holding exactly that one address (not an Address instance).
    target = Address(("127.0.0.1", 7687))
    results = target.resolve(resolver=None)
    assert not isinstance(results, Address)
    assert isinstance(results, list)
    assert len(results) == 1
    expected = IPv4Address(('127.0.0.1', 7687))
    assert results[0] == expected
예제 #2
0
    async def _connect(cls, address, security, loop):
        """ Attempt to establish a TCP connection to the address
        provided.

        :param address: the :class:`Address` to dial
        :param security: ``True`` to use ``Security.default()``, a
            :class:`Security` instance to use explicitly, or any falsy
            value for an unsecured connection
        :param loop: event loop to use; ``get_event_loop()`` is used
            when this is ``None``
        :return: a 3-tuple of reader, writer and security settings for
            the new connection (security is ``None`` when unsecured)
        :raise TypeError: if ``security`` is truthy but is neither
            ``True`` nor a :class:`Security` instance
        :raise BoltSecurityError: if an :class:`ssl.SSLError` occurs
            while connecting
        :raise BoltConnectionError: if a connection could not be
            established
        """
        assert isinstance(address, Address)
        if loop is None:
            loop = get_event_loop()
        connection_args = {
            "host": address.host,
            "port": address.port,
            "family": address.family,
            # TODO: other args
        }
        if security is True:
            security = Security.default()
        if isinstance(security, Security):
            ssl_context = security.to_ssl_context()
            connection_args["ssl"] = ssl_context
            # The host from the requested address is used as the TLS
            # server hostname.
            connection_args["server_hostname"] = address.host
        elif security:
            raise TypeError(
                "Unsupported security configuration {!r}".format(security))
        else:
            security = None
        log.debug("[#0000] C: <DIAL> %s", address)
        try:
            reader = BoltStreamReader(loop=loop)
            protocol = StreamReaderProtocol(reader, loop=loop)
            transport, _ = await loop.create_connection(
                lambda: protocol, **connection_args)
            writer = BoltStreamWriter(transport, protocol, reader, loop)
        except SSLError as err:
            # errno may be None; strerror(None) would raise TypeError and
            # hide the real error, so fall back to the exception message.
            reason = str(err) if err.errno is None else strerror(err.errno)
            log.debug("[#%04X] S: <REJECT> %s (%s %s)", 0, address, err.errno,
                      reason)
            raise BoltSecurityError("Failed to establish a secure connection",
                                    address) from err
        except OSError as err:
            # errno is None e.g. when create_connection raises a bare
            # OSError("Multiple exceptions: ...") after several candidate
            # addresses all fail.
            reason = str(err) if err.errno is None else strerror(err.errno)
            log.debug("[#%04X] S: <REJECT> %s (%s %s)", 0, address, err.errno,
                      reason)
            raise BoltConnectionError("Failed to establish a connection",
                                      address) from err
        else:
            local_address = Address(transport.get_extra_info("sockname"))
            remote_address = Address(transport.get_extra_info("peername"))
            log.debug("[#%04X] S: <ACCEPT> %s -> %s",
                      local_address.port_number, local_address, remote_address)
            return reader, writer, security
def test_address_resolve_with_custom_resolver():
    # python -m pytest tests/unit/test_addressing.py -s -k test_address_resolve_with_custom_resolver
    # The custom resolver expands the address into two host/port pairs; the
    # assertions expect "localhost" to be expanded further into both ::1 and
    # 127.0.0.1.
    # NOTE(review): the expected length and ordering depend on how the test
    # host resolves "localhost" (::1 before 127.0.0.1) — confirm on CI hosts
    # without IPv6.
    def custom_resolver(a):
        return [("127.0.0.1", 7687), ("localhost", 1234)]

    start = Address(("127.0.0.1", 7687))
    results = start.resolve(resolver=custom_resolver)
    assert not isinstance(results, Address)
    assert isinstance(results, list)
    assert len(results) == 3
    expected = [
        IPv4Address(('127.0.0.1', 7687)),
        IPv6Address(('::1', 1234, 0, 0)),
        IPv4Address(('127.0.0.1', 1234)),
    ]
    for actual, wanted in zip(results, expected):
        assert actual == wanted
def test_address_resolve_with_custom_resolver():
    # python -m pytest tests/unit/test_addressing.py -s -k test_address_resolve_with_custom_resolver
    # NOTE(review): a test with this exact name is defined earlier in this
    # module; this later definition shadows it, so pytest collects only one of
    # them. Consider giving one a distinct name.
    custom_resolver = lambda a: [("127.0.0.1", 7687), ("localhost", 1234)]

    address = Address(("127.0.0.1", 7687))
    # Restricting resolution to AF_INET should leave only IPv4 results.
    resolved = address.resolve(family=AF_INET, resolver=custom_resolver)
    assert not isinstance(resolved, Address)
    assert isinstance(resolved, list)
    # Assert the length directly (instead of ``if ...: else: assert False``)
    # so that a failure shows what was actually resolved.
    assert len(resolved) == 2, resolved
    assert resolved[0] == IPv4Address(('127.0.0.1', 7687))
    assert resolved[1] == IPv4Address(('127.0.0.1', 1234))
    async def _connect(cls, address, loop, config):
        """ Attempt to establish a TCP connection to the address
        provided.

        :param address: the :class:`Address` to dial
        :param loop: event loop on which to open the connection (must
            not be ``None``)
        :param config: :class:`PoolConfig`; its ``get_ssl_context()``
            result, when truthy, secures the connection
        :return: a 2-tuple of reader and writer for the new connection
        :raise BoltSecurityError: if an :class:`ssl.SSLError` occurs
            while connecting
        :raise BoltConnectionError: if a connection could not be
            established
        """
        assert isinstance(address, Address)
        assert loop is not None
        assert isinstance(config, PoolConfig)
        connection_args = {
            "host": address.host,
            "port": address.port,
            "family": address.family,
            # TODO: other args
        }
        ssl_context = config.get_ssl_context()
        if ssl_context:
            connection_args["ssl"] = ssl_context
            # The host from the requested address is used as the TLS
            # server hostname.
            connection_args["server_hostname"] = address.host
        log.debug("[#0000] C: <DIAL> %s", address)
        try:
            reader = BoltStreamReader(loop=loop)
            protocol = StreamReaderProtocol(reader, loop=loop)
            transport, _ = await loop.create_connection(
                lambda: protocol, **connection_args)
            writer = BoltStreamWriter(transport, protocol, reader, loop)
        except SSLError as err:
            # errno may be None; strerror(None) would raise TypeError and
            # hide the real error, so fall back to the exception message.
            reason = str(err) if err.errno is None else strerror(err.errno)
            log.debug("[#%04X] S: <REJECT> %s (%s %s)", 0, address, err.errno,
                      reason)
            raise BoltSecurityError("Failed to establish a secure connection",
                                    address) from err
        except OSError as err:
            # errno is None e.g. when create_connection raises a bare
            # OSError("Multiple exceptions: ...") after several candidate
            # addresses all fail.
            reason = str(err) if err.errno is None else strerror(err.errno)
            log.debug("[#%04X] S: <REJECT> %s (%s %s)", 0, address, err.errno,
                      reason)
            raise BoltConnectionError("Failed to establish a connection",
                                      address) from err
        else:
            local_address = Address(transport.get_extra_info("sockname"))
            remote_address = Address(transport.get_extra_info("peername"))
            log.debug("[#%04X] S: <ACCEPT> %s -> %s",
                      local_address.port_number, local_address, remote_address)
            return reader, writer
예제 #6
0
    async def _handshake(cls, reader, writer, protocol_version):
        """ Carry out a Bolt handshake, optionally requesting a
        specific protocol version.

        Sends ``MAGIC`` followed by up to four offered versions (highest
        first, the version block zero-padded to 16 bytes), then reads a
        4-byte reply and maps it to the matching handler class.

        :param reader: stream reader for the connection
        :param writer: stream writer for the connection
        :param protocol_version: version passed to
            ``cls.protocol_handlers`` to select candidate handlers
        :return: the handler class registered for the agreed version
        :raise ValueError: if no protocol handlers are available for the
            requested version
        :raise BoltConnectionLost: if an I/O error occurs on the
            underlying socket connection
        :raise BoltHandshakeError: if handshake completes without a
            successful negotiation
        """
        local_address = Address(writer.transport.get_extra_info("sockname"))
        remote_address = Address(writer.transport.get_extra_info("peername"))

        handlers = cls.protocol_handlers(protocol_version)
        if not handlers:
            # Interpolate here: passing the version as a second argument
            # would only store it in ``args`` and leave "%r" in the message.
            raise ValueError(
                "No protocol handlers available (requested Bolt %r)"
                % (protocol_version,))
        offered_versions = sorted(handlers.keys(), reverse=True)[:4]

        request_data = MAGIC + b"".join(
            v.to_bytes() for v in offered_versions).ljust(16, b"\x00")
        log.debug("[#%04X] C: <HANDSHAKE> %r", local_address.port_number,
                  request_data)
        writer.write(request_data)
        await writer.drain()
        response_data = await reader.readexactly(4)
        log.debug("[#%04X] S: <HANDSHAKE> %r", local_address.port_number,
                  response_data)
        try:
            agreed_version = Version.from_bytes(response_data)
        except ValueError as err:
            writer.close()
            raise BoltHandshakeError(
                "Unexpected handshake response %r" % response_data,
                remote_address, request_data, response_data) from err
        try:
            subclass = handlers[agreed_version]
        except KeyError:
            log.debug("Unsupported Bolt protocol version %s", agreed_version)
            raise BoltHandshakeError("Unsupported Bolt protocol version",
                                     remote_address, request_data,
                                     response_data)
        else:
            return subclass
예제 #7
0
async def a_main(prog):
    """Run one Cypher query given on the command line and print the results.

    The result's field names are printed tab-separated on the first line,
    followed by one tab-separated line per record with each value rendered
    via ``repr``.

    :param prog: program name shown in usage/help output
    """
    parser = ArgumentParser(prog=prog)
    parser.add_argument("cypher", help="Cypher query to execute")
    parser.add_argument("-a", "--auth", metavar="USER:PASSWORD", default="",
                        help="user name and password")
    parser.add_argument("-s", "--server-addr", metavar="HOST:PORT",
                        default=":7687", help="address of server")
    parser.add_argument("-v", "--verbose", action="store_true",
                        help="increase output verbosity")
    args = parser.parse_args()
    if args.verbose:
        watch("neo4j")
    server = Address.parse(args.server_addr)
    # "USER:PASSWORD" -> user and password; prompt if no password was given.
    user_name, _, secret = args.auth.partition(":")
    secret = secret or getpass()
    credentials = (user_name or "neo4j", secret)
    bolt = await Bolt.open(server, auth=credentials)
    try:
        result = await bolt.run(args.cypher)
        header = await result.fields()
        print("\t".join(header))
        async for record in result:
            print("\t".join(repr(value) for value in record))
    finally:
        await bolt.close()
예제 #8
0
    def __new__(cls, uri, **config):
        """Create a routing driver seeded from the address in ``uri``.

        The URI's network location is parsed into ``initial_address``
        (with ``default_port=DEFAULT_PORT``). A routing connection pool is
        built around it, and its routing table is updated once before the
        instance is returned.

        :param uri: routing connection URI
        :param config: driver options; also forwarded to every connection
            the pool opens
        :return: the new driver instance
        """
        cls._check_uri(uri)
        instance = object.__new__(cls)
        netloc = urlparse(uri).netloc
        initial_address = Address.parse(netloc, default_port=DEFAULT_PORT)
        instance.initial_address = initial_address
        # Treat an unset or explicit-None "encrypted" option as False.
        if config.get("encrypted") is None:
            config["encrypted"] = False
        instance._ssl_context = make_ssl_context(**config)
        instance.encrypted = instance._ssl_context is not None
        routing_context = cls.parse_routing_context(uri)

        def open_connection(address, **kwargs):
            # Per-call keyword arguments override the driver-level config.
            return Connection.open(address, **{**config, **kwargs})

        pool = RoutingConnectionPool(open_connection, initial_address,
                                     routing_context, initial_address,
                                     **config)
        try:
            pool.update_routing_table()
        except Exception:
            # Release the pool rather than leak it when the first routing
            # table update fails.
            pool.close()
            raise
        instance._pool = pool
        instance._max_retry_time = config.get(
            "max_retry_time", default_config["max_retry_time"])
        return instance
예제 #9
0
    def __new__(cls, uri, **config):
        """Create a direct driver for the single address in ``uri``.

        :param uri: connection URI; its network location is parsed into
            ``instance.address`` with ``default_port=DEFAULT_PORT``
        :param config: driver options; also forwarded to every connection
            opened through the pool
        :return: the new driver instance
        :raise Exception: whatever the initial acquire/release raises; the
            pool is closed before the error propagates
        """
        cls._check_uri(uri)
        instance = object.__new__(cls)
        # We keep the address containing the host name or IP address exactly
        # as-is from the original URI. This means that every new connection
        # will carry out DNS resolution, leading to the possibility that
        # the connection pool may contain multiple IP address keys, one for
        # an old address and one for a new address.
        parsed = urlparse(uri)
        instance.address = Address.parse(parsed.netloc,
                                         default_port=DEFAULT_PORT)
        # Treat an unset or explicit-None "encrypted" option as False.
        if config.get("encrypted") is None:
            config["encrypted"] = False
        instance._ssl_context = make_ssl_context(**config)
        instance.encrypted = instance._ssl_context is not None

        def connector(address, **kwargs):
            # Per-call keyword arguments override the driver-level config.
            return Connection.open(address, **dict(config, **kwargs))

        pool = ConnectionPool(connector, instance.address, **config)
        try:
            # Acquire and immediately release one connection, presumably
            # so connection problems surface at construction time.
            pool.release(pool.acquire())
        except Exception:
            # Close the pool instead of leaking it when that first attempt
            # fails (mirrors the routing driver's handling of its pool).
            pool.close()
            raise
        instance._pool = pool
        instance._max_retry_time = config.get("max_retry_time",
                                              default_config["max_retry_time"])
        return instance
예제 #10
0
 def parse_routing_info(cls, records):
     """ Parse the records returned from a getServers call and
     return a new RoutingTable instance.

     :param records: sequence expected to contain exactly one record with
         a "servers" entry (each server having "role" and "addresses")
         and a "ttl" entry
     :return: ``cls(routers, readers, writers, ttl)``
     :raise RoutingProtocolError: if there is not exactly one record, or
         the record cannot be parsed
     """
     if len(records) != 1:
         raise RoutingProtocolError("Expected exactly one record")
     record = records[0]
     routers, readers, writers = [], [], []
     buckets = (("ROUTE", routers), ("READ", readers), ("WRITE", writers))
     try:
         for server in record["servers"]:
             role = server["role"]
             parsed = [Address.parse(address, default_port=DEFAULT_PORT)
                       for address in server["addresses"]]
             # Servers whose role matches none of the buckets are ignored.
             for name, bucket in buckets:
                 if role == name:
                     bucket.extend(parsed)
                     break
         ttl = record["ttl"]
     except (KeyError, TypeError):
         raise RoutingProtocolError("Cannot parse routing info")
     return cls(routers, readers, writers, ttl)
예제 #11
0
 def parse_routing_info(cls, *, database, servers, ttl):
     """ Parse the records returned from the procedure call and
     return a new RoutingTable instance.

     :param database: passed through to the new table unchanged
     :param servers: iterable of server entries, each with a "role" and an
         iterable of "addresses"
     :param ttl: passed through to the new table unchanged
     :return: ``cls(database=..., routers=..., readers=..., writers=...,
         ttl=...)``
     :raise ValueError: if a server entry cannot be parsed
     """
     routers, readers, writers = [], [], []
     buckets = (("ROUTE", routers), ("READ", readers), ("WRITE", writers))
     try:
         for server in servers:
             role = server["role"]
             # default_port=7687, presumably applied when an address has no
             # explicit port.
             parsed = [Address.parse(address, default_port=7687)
                       for address in server["addresses"]]
             # Servers whose role matches none of the buckets are ignored.
             for name, bucket in buckets:
                 if role == name:
                     bucket.extend(parsed)
                     break
     except (KeyError, TypeError):
         raise ValueError("Cannot parse routing info")
     return cls(database=database, routers=routers, readers=readers,
                writers=writers, ttl=ttl)
예제 #12
0
def connect(address, *, timeout, custom_resolver, ssl_context, keep_alive):
    """ Connect and perform a handshake and return a valid Connection object,
    assuming a protocol version can be agreed.

    Each address produced by ``Address(address).resolve(...)`` is tried in
    turn. The first one that connects, is secured and completes the
    handshake is returned.

    :param address: unresolved address; its first element (the host) is
        passed to ``_secure``
    :param timeout: passed through to ``_connect``
    :param custom_resolver: passed as ``resolver`` to ``Address.resolve``
    :param ssl_context: passed through to ``_secure``
    :param keep_alive: passed through to ``_connect``
    :return: the result of ``_handshake`` for the first successful address
    :raise ServiceUnavailable: if resolution produced no addresses
    :raise Exception: the error from the last attempt, if every resolved
        address failed
    """
    last_error = None
    # Establish a connection to the host and port specified
    # Catches refused connections see:
    # https://docs.python.org/2/library/errno.html
    log.debug("[#0000]  C: <RESOLVE> %s", address)

    for resolved_address in Address(address).resolve(resolver=custom_resolver):
        s = None
        try:
            host = address[0]
            s = _connect(resolved_address, timeout, keep_alive)
            s = _secure(s, host, ssl_context)
            return _handshake(s, address)
        except Exception as error:
            # Close any partially set-up socket and try the next address.
            if s:
                s.close()
            last_error = error
    if last_error is None:
        # Wrap in a 1-tuple: ``address`` is itself a (host, port) tuple,
        # which ``%`` would otherwise unpack as two format arguments and
        # raise TypeError instead of ServiceUnavailable.
        raise ServiceUnavailable(
            "Failed to resolve addresses for %s" % (address,))
    else:
        raise last_error
예제 #13
0
 def parse_target(cls, target):
     """ Parse a target string to produce an address.

     An empty or missing ``target`` falls back to ``cls.default_target``.
     ``cls.default_host`` and ``cls.default_port`` are passed to
     ``Address.parse`` as defaults.
     """
     return Address.parse(target or cls.default_target,
                          default_host=cls.default_host,
                          default_port=cls.default_port)
예제 #14
0
def test_address_initialization(test_input, expected):
    # python -m pytest tests/unit/test_addressing.py -s -k test_address_initialization
    # An Address built from raw input exposes the expected family, host and
    # port, and renders the expected str() and repr().
    address = Address(test_input)
    for attribute in ("family", "host", "port"):
        assert getattr(address, attribute) == expected[attribute]
    assert str(address) == expected["str"]
    assert repr(address) == expected["repr"]
예제 #15
0
 def parse_targets(cls, *targets):
     """ Parse a sequence of target strings to produce an address
     list.

     The targets are joined with single spaces. When the joined string is
     empty, ``cls.default_targets`` is parsed instead.
     """
     joined = " ".join(targets) or cls.default_targets
     return Address.parse_list(joined, default_host=cls.default_host,
                               default_port=cls.default_port)
 def __init__(self, loop, opener, config, address):
     """Initialise bookkeeping for connections to ``address``.

     :param loop: event loop; ``get_event_loop()`` is used when ``None``
     :param opener: stored as-is
     :param config: object providing ``max_size`` and ``max_age``
     :param address: normalised via ``Address(...)``
     """
     self._loop = get_event_loop() if loop is None else loop
     self._opener = opener
     self._address = Address(address)
     self._max_size = config.max_size
     self._max_age = config.max_age
     # Start with empty in-use and free lists.
     self._in_use_list = deque()
     self._free_list = deque()
     self._waiting_list = WaitingList(loop=self._loop)
예제 #17
0
def test_serverinfo_initialization():
    # python -m pytest tests/unit/test_api.py -s -k test_serverinfo_initialization
    # A freshly built ServerInfo keeps the exact address and version objects
    # it was given and has no agent or connection id yet.
    from neo4j.addressing import Address

    address = Address(("bolt://localhost", 7687))
    version = neo4j.api.Version(3, 0)

    info = neo4j.api.ServerInfo(address, version)
    assert info.address is address
    assert info.protocol_version is version
    for attribute in ("agent", "connection_id"):
        assert getattr(info, attribute) is None
예제 #18
0
def test_serverinfo_with_metadata(test_input, expected_agent, expected_version_info):
    # python -m pytest tests/unit/test_api.py -s -k test_serverinfo_with_metadata
    # Feeding metadata into ServerInfo.update() should set the agent and the
    # reported version info.
    from neo4j.addressing import Address

    info = neo4j.api.ServerInfo(Address(("bolt://localhost", 7687)),
                                neo4j.api.Version(3, 0))
    info.update(test_input)

    assert info.agent == expected_agent
    assert info.version_info() == expected_version_info
예제 #19
0
    def __init__(self,
                 unresolved_address,
                 sock,
                 max_connection_lifetime,
                 *,
                 auth=None,
                 user_agent=None,
                 routing_context=None):
        """Wrap an already-connected socket as a Bolt connection.

        :param unresolved_address: the address as originally requested,
            stored unchanged
        :param sock: connected socket; its peer name is recorded in
            ``server_info`` together with ``self.PROTOCOL_VERSION``
        :param max_connection_lifetime: stored for later lifetime checks
        :param auth: falsy for no auth, a 2- or 3-tuple of arguments for a
            basic ``Auth``, or an object whose ``vars()`` form the auth dict
        :param user_agent: user agent string; ``get_user_agent()`` is used
            when falsy
        :param routing_context: stored unchanged
        :raise AuthError: if auth details cannot be determined, or the
            resulting credentials are ``None``
        """
        self.unresolved_address = unresolved_address
        self.socket = sock
        self.server_info = ServerInfo(Address(sock.getpeername()),
                                      self.PROTOCOL_VERSION)
        self.outbox = Outbox()
        self.inbox = Inbox(self.socket, on_error=self._set_defunct)
        self.packer = Packer(self.outbox)
        self.unpacker = Unpacker(self.inbox)
        self.responses = deque()
        self._max_connection_lifetime = max_connection_lifetime
        self._creation_timestamp = perf_counter()
        self.supports_multiple_results = False
        self.supports_multiple_databases = False
        self._is_reset = True
        self.routing_context = routing_context

        # Determine the user agent
        if user_agent:
            self.user_agent = user_agent
        else:
            self.user_agent = get_user_agent()

        # Determine auth details
        if not auth:
            self.auth_dict = {}
        elif isinstance(auth, tuple) and 2 <= len(auth) <= 3:
            from neo4j import Auth
            self.auth_dict = vars(Auth("basic", *auth))
        else:
            try:
                self.auth_dict = vars(auth)
            except (KeyError, TypeError):
                # Wrap ``auth`` in a 1-tuple: a tuple of the wrong length
                # would otherwise be unpacked by ``%`` and raise TypeError
                # instead of AuthError.
                raise AuthError(
                    "Cannot determine auth details from %r" % (auth,))

        # Check for missing password
        try:
            credentials = self.auth_dict["credentials"]
        except KeyError:
            pass
        else:
            if credentials is None:
                raise AuthError("Password cannot be None")
예제 #20
0
    async def open(cls,
                   address,
                   *,
                   auth=None,
                   security=False,
                   protocol_version=None,
                   loop=None):
        """ Open a socket connection and perform protocol version
        negotiation, in order to construct and return a Bolt client
        instance for a supported Bolt protocol version.

        If anything fails after the socket is connected, the writer is
        closed before the error is re-raised.

        :param address: tuples of host and port, such as
                        ("127.0.0.1", 7687)
        :param auth: passed to the new instance's ``__ainit__``
        :param security: passed to ``_connect`` (``True``, a
            ``Security`` instance, or falsy)
        :param protocol_version: version requested during the handshake
        :param loop: event loop passed to ``_connect``
        :return: instance of a Bolt subclass
        :raise BoltConnectionError: if a connection could not be
            established
        :raise BoltConnectionLost: if an I/O error occurs on the
            underlying socket connection
        :raise BoltHandshakeError: if handshake completes without a
            successful negotiation
        :raise TypeError: if any of the arguments provided are passed
            as incompatible types
        :raise ValueError: if any of the arguments provided are passed
            with unsupported values
        """

        # Connect
        address = Address(address)
        reader, writer, security = await cls._connect(address, security, loop)

        try:

            # Handshake
            subclass = await cls._handshake(reader, writer, protocol_version)

            # Instantiation
            inst = subclass(reader, writer)
            inst.security = security
            assert hasattr(inst, "__ainit__")
            await inst.__ainit__(auth)
            return inst

        except BaseException:
            # Clean up on any failure, not just BoltError (e.g. a handshake
            # ValueError or cancellation), so the socket is never leaked.
            # Half-close only where supported: SSL transports cannot, and
            # their write_eof() raises NotImplementedError, which would
            # mask the original error and skip close().
            try:
                if writer.can_write_eof():
                    writer.write_eof()
            finally:
                writer.close()
            raise
    async def open(cls, address, *, auth=None, loop=None, **config):
        """ Open a socket connection and perform protocol version
        negotiation, in order to construct and return a Bolt client
        instance for a supported Bolt protocol version.

        If anything fails after the socket is connected, the writer is
        closed before the error is re-raised.

        :param address: tuples of host and port, such as
                        ("127.0.0.1", 7687)
        :param auth: passed to the new instance's ``__ainit__``
        :param loop: event loop; ``get_event_loop()`` is used when ``None``
        :param config: options consumed by ``PoolConfig.consume``
        :return: instance of a Bolt subclass
        :raise BoltConnectionError: if a connection could not be
            established
        :raise BoltConnectionLost: if an I/O error occurs on the
            underlying socket connection
        :raise BoltHandshakeError: if handshake completes without a
            successful negotiation
        :raise TypeError: if any of the arguments provided are passed
            as incompatible types
        :raise ValueError: if any of the arguments provided are passed
            with unsupported values
        """

        # Args
        address = Address(address)
        if loop is None:
            loop = get_event_loop()
        config = PoolConfig.consume(config)

        # Connect
        reader, writer = await cls._connect(address, loop, config)

        try:

            # Handshake
            subclass = await cls._handshake(reader, writer,
                                            config.protocol_version)

            # Instantiation
            obj = subclass(reader, writer)
            obj.secure = bool(config.secure)
            assert hasattr(obj, "__ainit__")
            await obj.__ainit__(auth)
            return obj

        except BaseException:
            # Clean up on any failure, not just BoltError (e.g. a handshake
            # ValueError or cancellation), so the socket is never leaked.
            # Half-close only where supported: SSL transports cannot, and
            # their write_eof() raises NotImplementedError, which would
            # mask the original error and skip close().
            try:
                if writer.can_write_eof():
                    writer.write_eof()
            finally:
                writer.close()
            raise
예제 #22
0
    def __init__(self,
                 unresolved_address,
                 sock,
                 *,
                 auth=None,
                 protocol_version=None,
                 **config):
        """Wrap an already-connected socket as a Bolt connection.

        :param unresolved_address: the address as originally requested,
            stored unchanged
        :param sock: connected socket; its peer name is recorded as the
            server address
        :param auth: falsy for no auth, a 2- or 3-tuple of arguments for a
            basic ``Auth``, or an object whose ``vars()`` form the auth dict
        :param protocol_version: stored and passed to ``ServerInfo``
        :param config: options consumed by ``PoolConfig.consume``
            (``max_age`` and ``user_agent`` are read here)
        :raise AuthError: if auth details cannot be determined, or the
            resulting credentials are ``None``
        """
        self.config = PoolConfig.consume(config)
        self.protocol_version = protocol_version
        self.unresolved_address = unresolved_address
        self.socket = sock
        self.server = ServerInfo(Address(sock.getpeername()), protocol_version)
        self.outbox = Outbox()
        # 32768 is presumably the read buffer size in bytes — confirm.
        self.inbox = Inbox(BufferedSocket(self.socket, 32768),
                           on_error=self._set_defunct)
        self.packer = Packer(self.outbox)
        self.unpacker = Unpacker(self.inbox)
        self.responses = deque()
        self._max_connection_lifetime = self.config.max_age
        self._creation_timestamp = perf_counter()

        # Determine the user agent
        user_agent = self.config.user_agent
        if user_agent:
            self.user_agent = user_agent
        else:
            self.user_agent = get_user_agent()

        # Determine auth details
        if not auth:
            self.auth_dict = {}
        elif isinstance(auth, tuple) and 2 <= len(auth) <= 3:
            from neo4j import Auth
            self.auth_dict = vars(Auth("basic", *auth))
        else:
            try:
                self.auth_dict = vars(auth)
            except (KeyError, TypeError):
                # Wrap ``auth`` in a 1-tuple: a tuple of the wrong length
                # would otherwise be unpacked by ``%`` and raise TypeError
                # instead of AuthError.
                raise AuthError(
                    "Cannot determine auth details from %r" % (auth,))

        # Check for missing password
        try:
            credentials = self.auth_dict["credentials"]
        except KeyError:
            pass
        else:
            if credentials is None:
                raise AuthError("Password cannot be None")
예제 #23
0
    def __init__(self, protocol_version, unresolved_address, sock, **config):
        """Wrap an already-connected socket as a Bolt connection.

        Config keys read (via ``config.get``):
        ``max_connection_lifetime`` (default
        ``DEFAULT_MAX_CONNECTION_LIFETIME``), ``user_agent``, ``auth`` and
        ``der_encoded_server_certificate``.

        :param protocol_version: stored and passed to ``ServerInfo``
        :param unresolved_address: the address as originally requested,
            stored unchanged
        :param sock: connected socket; its peer name is recorded as the
            server address
        :raise AuthError: if auth details cannot be determined, or the
            resulting credentials are ``None``
        """
        self.protocol_version = protocol_version
        self.unresolved_address = unresolved_address
        self.socket = sock
        self.server = ServerInfo(Address(sock.getpeername()), protocol_version)
        self.outbox = Outbox()
        # 32768 is presumably the read buffer size in bytes — confirm.
        self.inbox = Inbox(BufferedSocket(self.socket, 32768),
                           on_error=self._set_defunct)
        self.packer = Packer(self.outbox)
        self.unpacker = Unpacker(self.inbox)
        self.responses = deque()
        self._max_connection_lifetime = config.get(
            "max_connection_lifetime", DEFAULT_MAX_CONNECTION_LIFETIME)
        self._creation_timestamp = perf_counter()

        # Determine the user agent
        user_agent = config.get("user_agent")
        if user_agent:
            self.user_agent = user_agent
        else:
            self.user_agent = get_user_agent()

        # Determine auth details
        auth = config.get("auth")
        if not auth:
            self.auth_dict = {}
        elif isinstance(auth, tuple) and 2 <= len(auth) <= 3:
            self.auth_dict = vars(AuthToken("basic", *auth))
        else:
            try:
                self.auth_dict = vars(auth)
            except (KeyError, TypeError):
                # Wrap ``auth`` in a 1-tuple: a tuple of the wrong length
                # would otherwise be unpacked by ``%`` and raise TypeError
                # instead of AuthError.
                raise AuthError(
                    "Cannot determine auth details from %r" % (auth,))

        # Check for missing password
        try:
            credentials = self.auth_dict["credentials"]
        except KeyError:
            pass
        else:
            if credentials is None:
                raise AuthError("Password cannot be None")

        # Pick up the server certificate, if any
        self.der_encoded_server_certificate = config.get(
            "der_encoded_server_certificate")
예제 #24
0
def test_address_parse_with_invalid_input(test_input, expected):
    # python -m pytest tests/unit/test_addressing.py -s -k test_address_parse_with_invalid_input
    # Invalid input must make Address.parse raise the parametrized exception
    # type; the return value is never used.
    with pytest.raises(expected):
        Address.parse(test_input)
예제 #25
0
def test_address_parse_with_ipv4(test_input, expected):
    # python -m pytest tests/unit/test_addressing.py -s -k test_address_parse_with_ipv4
    # Each IPv4 test input should parse to the expected address.
    assert Address.parse(test_input) == expected
예제 #26
0
 def local_address(self):
     """Return the transport's ``sockname`` (local end) as an Address."""
     sockname = self.__transport.get_extra_info("sockname")
     return Address(sockname)
예제 #27
0
def test_address_should_parse_ipv6(test_input, expected):
    # python -m pytest tests/unit/test_addressing.py -s -k test_address_should_parse_ipv6
    # Each IPv6 test input should parse to the expected address.
    assert Address.parse(test_input) == expected
예제 #28
0
def test_address_parse_list_with_invalid_input(test_input, expected):
    # python -m pytest tests/unit/test_addressing.py -s -k test_address_parse_list_with_invalid_input
    # Invalid arguments must make Address.parse_list raise TypeError.
    # NOTE(review): ``expected`` is accepted but unused and the exception type
    # is hard-coded — confirm whether pytest.raises(expected) was intended.
    with pytest.raises(TypeError):
        Address.parse_list(*test_input)
예제 #29
0
 def remote_address(self):
     """Return the transport's ``peername`` (remote end) as an Address."""
     peername = self.__transport.get_extra_info("peername")
     return Address(peername)
예제 #30
0
def test_address_parse_list(test_input, expected):
    # python -m pytest tests/unit/test_addressing.py -s -k test_address_parse_list
    # ``expected`` is the number of addresses parse_list should return.
    parsed = Address.parse_list(*test_input)
    assert len(parsed) == expected