def setUp(self):
    """Build the table under test and the table it will be updated from."""
    # Table to be overwritten: two routers, one reader, no writers, TTL 0.
    old_routers = [("192.168.1.1", 7687), ("192.168.1.2", 7687)]
    old_readers = [("192.168.1.3", 7687)]
    self.table = RoutingTable(old_routers, old_readers, [], 0)
    # Replacement table on localhost ports 9001-9006 with a TTL of 300.
    new_routers = [("127.0.0.1", port) for port in (9001, 9002, 9003)]
    new_readers = [("127.0.0.1", port) for port in (9004, 9005)]
    new_writers = [("127.0.0.1", 9006)]
    self.new_table = RoutingTable(new_routers, new_readers, new_writers, 300)
Пример #2
0
class RoutingTableUpdateTestCase(TestCase):
    """Checks that ``RoutingTable.update`` takes every field from another table."""

    def setUp(self):
        # Table to be overwritten: two routers, one reader, no writers, TTL 0.
        self.table = RoutingTable(
            database=DEFAULT_DATABASE,
            routers=[("192.168.1.1", 7687), ("192.168.1.2", 7687)],
            readers=[("192.168.1.3", 7687)],
            writers=[],
            ttl=0,
        )
        # Replacement table on localhost ports 9001-9006 with a TTL of 300.
        host = "127.0.0.1"
        self.new_table = RoutingTable(
            database=DEFAULT_DATABASE,
            routers=[(host, 9001), (host, 9002), (host, 9003)],
            readers=[(host, 9004), (host, 9005)],
            writers=[(host, 9006)],
            ttl=300,
        )

    def _updated_table(self):
        """Apply ``self.new_table`` to ``self.table`` and return ``self.table``."""
        self.table.update(self.new_table)
        return self.table

    def test_update_should_replace_routers(self):
        expected = {("127.0.0.1", 9001), ("127.0.0.1", 9002), ("127.0.0.1", 9003)}
        assert self._updated_table().routers == expected

    def test_update_should_replace_readers(self):
        expected = {("127.0.0.1", 9004), ("127.0.0.1", 9005)}
        assert self._updated_table().readers == expected

    def test_update_should_replace_writers(self):
        expected = {("127.0.0.1", 9006)}
        assert self._updated_table().writers == expected

    def test_update_should_replace_ttl(self):
        assert self._updated_table().ttl == 300
class RoutingTableUpdateTestCase(TestCase):
    """Checks that ``RoutingTable.update`` takes every field from another table."""

    def setUp(self):
        # Positional arguments: routers, readers, writers, ttl.
        self.table = RoutingTable(
            [("192.168.1.1", 7687), ("192.168.1.2", 7687)],
            [("192.168.1.3", 7687)],
            [],
            0,
        )
        host = "127.0.0.1"
        self.new_table = RoutingTable(
            [(host, 9001), (host, 9002), (host, 9003)],
            [(host, 9004), (host, 9005)],
            [(host, 9006)],
            300,
        )

    def _updated_table(self):
        """Apply ``self.new_table`` to ``self.table`` and return ``self.table``."""
        self.table.update(self.new_table)
        return self.table

    def test_update_should_replace_routers(self):
        expected = {("127.0.0.1", 9001), ("127.0.0.1", 9002), ("127.0.0.1", 9003)}
        assert self._updated_table().routers == expected

    def test_update_should_replace_readers(self):
        expected = {("127.0.0.1", 9004), ("127.0.0.1", 9005)}
        assert self._updated_table().readers == expected

    def test_update_should_replace_writers(self):
        expected = {("127.0.0.1", 9006)}
        assert self._updated_table().writers == expected

    def test_update_should_replace_ttl(self):
        assert self._updated_table().ttl == 300
Пример #4
0
 def setUp(self):
     """Create the table to be updated and the table supplying new values."""
     def local(port):
         # Address tuple on the loopback host.
         return ("127.0.0.1", port)

     self.table = RoutingTable(
         database=DEFAULT_DATABASE,
         routers=[("192.168.1.1", 7687), ("192.168.1.2", 7687)],
         readers=[("192.168.1.3", 7687)],
         writers=[],
         ttl=0,
     )
     self.new_table = RoutingTable(
         database=DEFAULT_DATABASE,
         routers=[local(9001), local(9002), local(9003)],
         readers=[local(9004), local(9005)],
         writers=[local(9006)],
         ttl=300,
     )
 def __init__(self, loop, opener, config, addresses, routing_context):
     """Initialise the routing pool state.

     :param loop: event loop to use, or ``None`` to use ``get_event_loop()``
     :param opener: connection opener, kept as ``self._opener``
     :param config: pool configuration; ``config.max_size`` is read here
     :param addresses: initial router addresses, used to seed the routing table
     :param routing_context: routing context, stored as given
     """
     self._loop = get_event_loop() if loop is None else loop
     self._opener = opener
     self._config = config
     self._pools = {}
     self._missing_writer = False
     try:
         self._refresh_lock = Lock(loop=self._loop)
     except TypeError:
         # asyncio.Lock lost its ``loop`` parameter in Python 3.10 (passing it
         # raises TypeError); there the lock binds to the running loop lazily.
         self._refresh_lock = Lock()
     self._routing_context = routing_context
     self._max_size_per_host = config.max_size
     self._initial_routers = addresses
     self._routing_table = RoutingTable(addresses)
     # Create a pool for each address in the seed table.
     self._activate_new_pools_in(self._routing_table)
 def test_should_return_routing_table_on_valid_record_with_extra_role(self):
     """An unknown extra role in the record must not disturb the known roles."""
     record = VALID_ROUTING_RECORD_WITH_EXTRA_ROLE
     table = RoutingTable.parse_routing_info(record["servers"], record["ttl"])
     host = "127.0.0.1"
     assert table.routers == {(host, 9001), (host, 9002), (host, 9003)}
     assert table.readers == {(host, 9004), (host, 9005)}
     assert table.writers == {(host, 9006)}
     assert table.ttl == 300
Пример #7
0
 async def get_routing_table(self, context=None):
     """Call the cluster routing procedure on this connection and parse it.

     :param context: optional mapping sent as the procedure's ``$context``
         (an empty dict is sent when omitted)
     :return: the table built by ``RoutingTable.parse_routing_info``
     :raise BoltRoutingError: if the call returns no record, if the server
         reports ``ProcedureNotFound``, or if a ``ValueError`` is raised
         while running the call or parsing its result
     """
     try:
         result = await self.run(
             "CALL dbms.cluster.routing.getRoutingTable($context)",
             {"context": dict(context or {})})
         record = await result.single()
         if not record:
             raise BoltRoutingError("Routing table call returned no data",
                                    self.remote_address)
         assert isinstance(record, Record)
         servers, ttl = record["servers"], record["ttl"]
         log.debug("[#%04X] S: <ROUTING> servers=%r ttl=%r",
                   self.local_address.port_number, servers, ttl)
         return RoutingTable.parse_routing_info(servers, ttl)
     except BoltFailure as failure:
         # Only a missing procedure is translated; other failures re-raise.
         if failure.title != "ProcedureNotFound":
             raise
         raise BoltRoutingError("Server does not support routing",
                                self.remote_address) from failure
     except ValueError as bad_value:
         raise BoltRoutingError("Invalid routing table",
                                self.remote_address) from bad_value
Пример #8
0
 def test_should_return_all_distinct_servers_in_routing_table(self):
     """Ports 9001 and 9002 appear under two roles but are reported once."""
     servers = [
         {"role": "ROUTE",
          "addresses": ["127.0.0.1:9001", "127.0.0.1:9002", "127.0.0.1:9003"]},
         {"role": "READ",
          "addresses": ["127.0.0.1:9001", "127.0.0.1:9005"]},
         {"role": "WRITE",
          "addresses": ["127.0.0.1:9002"]},
     ]
     table = RoutingTable.parse_routing_info(
         database=DEFAULT_DATABASE,
         servers=servers,
         ttl=300,
     )
     expected = {("127.0.0.1", port) for port in (9001, 9002, 9003, 9005)}
     assert table.servers() == expected
Пример #9
0
    def __init__(self, opener, pool_config, workspace_config, routing_context,
                 addresses):
        """Create a routing pool seeded with the given router addresses.

        :param opener: connection opener, forwarded to the base pool
        :param pool_config: pool configuration, forwarded to the base pool
        :param workspace_config: workspace configuration; its ``database``
            keys the initial routing table
        :param routing_context: Dictionary with routing information, or
            ``None``. It is copied, so the caller's dictionary is never
            modified. It must not contain the reserved key ``'address'``.
        :param addresses: router addresses; the first one becomes
            ``self.init_address``
        :raise ConfigurationError: if ``routing_context`` contains ``'address'``
        """
        super(Neo4jPool, self).__init__(opener, pool_config, workspace_config)
        # Each database has a routing table; the default database is a special case.
        log.debug("[#0000]  C: <NEO4J POOL> routing addresses %r", addresses)
        self.init_address = addresses[0]
        self.routing_tables = {
            workspace_config.database:
            RoutingTable(database=workspace_config.database, routers=addresses)
        }
        # Copy before adding the reserved "address" key: mutating the caller's
        # dict would make a second pool built from the same dict fail the
        # reserved-key check below.
        self.routing_context = dict(routing_context or {})
        if "address" in self.routing_context:
            raise ConfigurationError(
                "The key 'address' is reserved for routing context.")
        self.routing_context["address"] = str(self.init_address)
        self.refresh_lock = Lock()
Пример #10
0
 def __init__(self, opener, pool_config, workspace_config, routing_context, addresses):
     """Create a routing pool with one routing table for the workspace database.

     :param addresses: router addresses seeding that table
     :param routing_context: routing context, stored as given
     """
     super(Neo4jPool, self).__init__(opener, pool_config, workspace_config)
     log.debug("[#0000]  C: <NEO4J POOL> routing addresses %r", addresses)
     # Routing tables are kept per database name.
     seed_table = RoutingTable(database=workspace_config.database, routers=addresses)
     self.routing_tables = {workspace_config.database: seed_table}
     self.routing_context = routing_context
     self.refresh_lock = Lock()
Пример #11
0
 def test_should_be_fresh_after_update(self):
     """A freshly parsed valid record is fresh for both reads and writes."""
     record = VALID_ROUTING_RECORD
     table = RoutingTable.parse_routing_info(
         database=DEFAULT_DATABASE, servers=record["servers"], ttl=record["ttl"])
     for readonly in (True, False):
         assert table.is_fresh(readonly=readonly)
Пример #12
0
 def test_should_become_stale_if_no_writers(self):
     """Dropping all writers should only cost write freshness."""
     record = VALID_ROUTING_RECORD
     table = RoutingTable.parse_routing_info(
         database=DEFAULT_DATABASE, servers=record["servers"], ttl=record["ttl"])
     table.writers.clear()
     fresh_for_read = table.is_fresh(readonly=True)
     assert fresh_for_read
     fresh_for_write = table.is_fresh(readonly=False)
     assert not fresh_for_write
Пример #13
0
 def test_should_return_routing_table_on_valid_record(self):
     """Each role in the record lands in the matching table attribute."""
     record = VALID_ROUTING_RECORD
     table = RoutingTable.parse_routing_info(
         database=DEFAULT_DATABASE, servers=record["servers"], ttl=record["ttl"])
     host = "127.0.0.1"
     assert table.routers == {(host, 9001), (host, 9002), (host, 9003)}
     assert table.readers == {(host, 9004), (host, 9005)}
     assert table.writers == {(host, 9006)}
     assert table.ttl == 300
    def __init__(self, opener, pool_config, workspace_config, routing_context, address):
        """Create a routing pool seeded with a single router address.

        :param opener: connection opener, forwarded to the base pool
        :param pool_config: pool configuration, forwarded to the base pool
        :param workspace_config: workspace configuration; its ``database``
            keys the initial routing table
        :param routing_context: Dictionary with routing information, stored
            as given
        :param address: initial router address; stored as ``self.address``
            and used as the only router of the initial routing table
        """
        super(Neo4jPool, self).__init__(opener, pool_config, workspace_config)
        # Each database has a routing table; the default database is a special case.
        log.debug("[#0000]  C: <NEO4J POOL> routing address %r", address)
        self.address = address
        self.routing_tables = {
            workspace_config.database:
            RoutingTable(database=workspace_config.database, routers=[address])
        }
        self.routing_context = routing_context
        self.refresh_lock = Lock()
Пример #15
0
    def fetch_routing_table(self, *, address, timeout, database):
        """ Fetch a routing table from a given router address.

        A table without writers is accepted: that likely indicates a
        temporary state, such as leader switching, so it is not an error.

        :param address: router address
        :param timeout: seconds
        :param database: the database name
        :type database: str

        :return: a new RoutingTable instance, or None if the given router is
                 currently unable to provide routing information (i.e.
                 ``fetch_routing_info`` returned None)

        :raise BoltRoutingError: if the routing information is empty, or the
                                 parsed table has no routers or no readers

        Exceptions raised by ``fetch_routing_info`` or
        ``RoutingTable.parse_routing_info`` propagate unchanged.
        """
        new_routing_info = self.fetch_routing_info(address=address,
                                                   timeout=timeout,
                                                   database=database)
        if new_routing_info is None:
            return None
        if not new_routing_info:
            raise BoltRoutingError("Invalid routing table", address)

        # Only the first routing record is used.
        record = new_routing_info[0]
        new_routing_table = RoutingTable.parse_routing_info(
            database=database, servers=record["servers"], ttl=record["ttl"])

        # Routers and readers are mandatory; writers are not (see docstring).
        if len(new_routing_table.routers) == 0:
            raise BoltRoutingError("No routing servers returned from server",
                                   address)
        if len(new_routing_table.readers) == 0:
            raise BoltRoutingError("No read servers returned from server",
                                   address)

        return new_routing_table
Пример #16
0
    def fetch_routing_table(self, address):
        """ Fetch a routing table from a given router address.

        A table without writers is accepted (this likely indicates a
        temporary state such as leader switching); instead of failing,
        ``self.missing_writer`` is set to whether the fetched table has no
        writers. The flag is updated before the router/reader checks, so it
        reflects the fetched table even if that table is then rejected.

        :param address: router address
        :return: a new RoutingTable instance or None if the given router is
                 currently unable to provide routing information (i.e.
                 ``fetch_routing_info`` returned None)
        :raise BoltRoutingError: if the routing information is empty, or the
                                 parsed table has no routers or no readers
        """
        new_routing_info = self.fetch_routing_info(address)
        if new_routing_info is None:
            return None
        if not new_routing_info:
            raise BoltRoutingError("Invalid routing table", address)

        # Only the first routing record is used.
        record = new_routing_info[0]
        new_routing_table = RoutingTable.parse_routing_info(record["servers"],
                                                            record["ttl"])

        # Flag that we are reading in the absence of a writer.
        self.missing_writer = (len(new_routing_table.writers) == 0)

        # Routers and readers are mandatory.
        if len(new_routing_table.routers) == 0:
            raise BoltRoutingError("No routing servers returned from server",
                                   address)
        if len(new_routing_table.readers) == 0:
            raise BoltRoutingError("No read servers returned from server",
                                   address)

        return new_routing_table
Пример #17
0
 def __init__(self, opener, pool_config, addresses, routing_context):
     """Initialise the pool with a routing table seeded from ``addresses``."""
     super(Neo4jPool, self).__init__(opener, pool_config)
     # No writer shortage is recorded until a routing table has been fetched.
     self.missing_writer = False
     self.refresh_lock = Lock()
     self.routing_context = routing_context
     self.routing_table = RoutingTable(addresses)
Пример #18
0
class Neo4jPool(IOPool):
    """ Connection pool with routing table.
    """

    @classmethod
    def open(cls, *addresses, auth=None, routing_context=None, **config):
        """ Create a pool for the given router addresses and fetch an
        initial routing table.

        :param addresses: initial router addresses
        :param auth: authentication details passed to ``Bolt.open``
        :param routing_context: routing context sent with routing requests
        :param config: pool settings, consumed by ``PoolConfig.consume``
        :return: the new pool
        :raise: whatever ``update_routing_table`` raises; the pool is closed
                before the exception propagates
        """
        pool_config = PoolConfig.consume(config)

        def opener(addr, timeout):
            return Bolt.open(addr, auth=auth, timeout=timeout, **pool_config)

        pool = cls(opener, pool_config, addresses, routing_context)
        try:
            pool.update_routing_table()
        except Exception:
            # Release anything opened during the failed bootstrap.
            pool.close()
            raise
        else:
            return pool

    def __init__(self, opener, pool_config, addresses, routing_context):
        super(Neo4jPool, self).__init__(opener, pool_config)
        self.routing_table = RoutingTable(addresses)
        self.routing_context = routing_context
        # Whether the routing table is known to lack writers; maintained by
        # fetch_routing_table and ensure_routing_table_is_fresh.
        self.missing_writer = False
        self.refresh_lock = Lock()

    def __repr__(self):
        return "<{} addresses={!r}>".format(self.__class__.__name__,
                                            self.routing_table.initial_routers)

    @property
    def initial_address(self):
        """ The first of the initial router addresses. """
        return self.routing_table.initial_routers[0]

    def fetch_routing_info(self, address, timeout=None):
        """ Fetch raw routing info from a given router address.

        :param address: router address
        :param timeout: seconds
        :return: list of routing records (one dict per record, keyed by the
                 result field names) or None if a ServiceUnavailable error
                 occurred, in which case the address is deactivated
        :raise ServiceUnavailable: if the server does not support routing or
                                   if routing support is broken
        """
        metadata = {}
        records = []
        # Send the same value that is logged: an empty map rather than null
        # when no routing context was configured.
        context = self.routing_context or {}

        def fail(md):
            if md.get("code") == "Neo.ClientError.Procedure.ProcedureNotFound":
                raise BoltRoutingError("Server does not support routing", address)
            else:
                raise BoltRoutingError("Routing support broken on server", address)

        try:
            with self._acquire(address, timeout) as cx:
                log.debug("[#%04X]  C: <ROUTING> query=%r", cx.local_port, context)
                cx.run("CALL dbms.cluster.routing.getRoutingTable($context)",
                       {"context": context}, on_success=metadata.update, on_failure=fail)
                cx.pull(on_success=metadata.update, on_records=records.extend)
                cx.send_all()
                cx.fetch_all()
                routing_info = [dict(zip(metadata.get("fields", ()), values)) for values in records]
                log.debug("[#%04X]  S: <ROUTING> info=%r", cx.local_port, routing_info)
            return routing_info
        except BoltRoutingError as error:
            raise ServiceUnavailable(*error.args)
        except ServiceUnavailable:
            self.deactivate(address)
            return None

    def fetch_routing_table(self, address, timeout=None):
        """ Fetch a routing table from a given router address.

        A table without writers is accepted (this likely indicates a
        temporary state such as leader switching); ``self.missing_writer``
        is set to whether the fetched table has no writers, before the
        router/reader checks below.

        :param address: router address
        :param timeout: seconds
        :return: a new RoutingTable instance or None if the given router is
                 currently unable to provide routing information
        :raise BoltRoutingError: if the routing information is empty, or the
                                 parsed table has no routers or no readers
        :raise ServiceUnavailable: propagated from ``fetch_routing_info``
        """
        new_routing_info = self.fetch_routing_info(address, timeout)
        if new_routing_info is None:
            return None
        if not new_routing_info:
            raise BoltRoutingError("Invalid routing table", address)

        # Only the first routing record is used.
        record = new_routing_info[0]
        new_routing_table = RoutingTable.parse_routing_info(record["servers"], record["ttl"])

        # Flag that we are reading in the absence of a writer.
        self.missing_writer = (len(new_routing_table.writers) == 0)

        # Routers and readers are mandatory.
        if len(new_routing_table.routers) == 0:
            raise BoltRoutingError("No routing servers returned from server", address)
        if len(new_routing_table.readers) == 0:
            raise BoltRoutingError("No read servers returned from server", address)

        return new_routing_table

    def update_routing_table_from(self, *routers):
        """ Try to update routing tables with the given routers, in order.

        :return: True if the routing table is successfully updated,
        otherwise False
        """
        log.debug("Attempting to update routing table from %s",
                  ", ".join(map(repr, routers)))
        for router in routers:
            new_routing_table = self.fetch_routing_table(router)
            if new_routing_table is not None:
                self.routing_table.update(new_routing_table)
                log.debug("Successfully updated routing table from %r (%r)",
                          router, self.routing_table)
                return True
        return False

    def update_routing_table(self):
        """ Update the routing table from the first router able to provide
        valid routing information.

        When the table is known to lack writers, the initial address is tried
        first; otherwise the current routers are tried, then the initial
        address (if it was not among them).

        :raise ServiceUnavailable: if no router provides routing information
        """
        # copied because it can be modified
        existing_routers = list(self.routing_table.routers)

        has_tried_initial_routers = False
        if self.missing_writer:
            has_tried_initial_routers = True
            if self.update_routing_table_from(self.initial_address):
                return

        if self.update_routing_table_from(*existing_routers):
            return

        if not has_tried_initial_routers and self.initial_address not in existing_routers:
            if self.update_routing_table_from(self.initial_address):
                return

        # None of the routers have been successful, so just fail
        log.error("Unable to retrieve routing information")
        raise ServiceUnavailable("Unable to retrieve routing information")

    def update_connection_pool(self):
        """ Deactivate pooled connections to addresses that are no longer
        in the routing table (without touching the routing table itself).
        """
        servers = self.routing_table.servers()
        for address in list(self.connections):
            if address not in servers:
                super(Neo4jPool, self).deactivate(address)

    def ensure_routing_table_is_fresh(self, access_mode):
        """ Update the routing table if stale.

        This method performs two freshness checks, before and after acquiring
        the refresh lock. If the routing table is already fresh on entry, the
        method exits immediately; otherwise, the refresh lock is acquired and
        the second freshness check that follows determines whether an update
        is still required.

        This method is thread-safe.

        :return: `True` if an update was required, `False` otherwise.
        """
        from neo4j.api import READ_ACCESS
        readonly = (access_mode == READ_ACCESS)
        if self.routing_table.is_fresh(readonly=readonly):
            return False
        with self.refresh_lock:
            if self.routing_table.is_fresh(readonly=readonly):
                if readonly:
                    # if reader is fresh but writers is not fresh, then we are reading in absence of writer
                    self.missing_writer = not self.routing_table.is_fresh(readonly=False)
                return False
            self.update_routing_table()
            self.update_connection_pool()
            return True

    def _select_address(self, access_mode=None):
        """ Selects the address with the fewest in-use connections.

        Readers are considered for READ_ACCESS, writers otherwise; ties are
        broken at random.

        :raise ReadServiceUnavailable: if no reader is available
        :raise WriteServiceUnavailable: if no writer is available
        """
        from neo4j.api import READ_ACCESS
        self.ensure_routing_table_is_fresh(access_mode)
        if access_mode == READ_ACCESS:
            addresses = self.routing_table.readers
        else:
            addresses = self.routing_table.writers
        addresses_by_usage = {}
        for address in addresses:
            addresses_by_usage.setdefault(self.in_use_connection_count(address), []).append(address)
        if not addresses_by_usage:
            if access_mode == READ_ACCESS:
                raise ReadServiceUnavailable("No read service currently available")
            else:
                raise WriteServiceUnavailable("No write service currently available")
        return choice(addresses_by_usage[min(addresses_by_usage)])

    def acquire(self, access_mode=None, timeout=None):
        """ Acquire a connection for the given access mode.

        Addresses that fail with ServiceUnavailable are deactivated and
        another address is selected.

        :raise SessionExpired: if no server for the access mode is available
        """
        from neo4j.api import check_access_mode
        access_mode = check_access_mode(access_mode)
        while True:
            try:
                address = self._select_address(access_mode)
            except (ReadServiceUnavailable, WriteServiceUnavailable) as err:
                raise SessionExpired("Failed to obtain connection towards '%s' server." % access_mode) from err
            try:
                connection = self._acquire(address, timeout=timeout)  # should always be a resolved address
            except ServiceUnavailable:
                self.deactivate(address)
            else:
                return connection

    def deactivate(self, address):
        """ Deactivate an address from the connection pool,
        if present, remove from the routing table and also closing
        all idle connections to that address.
        """
        log.debug("[#0000]  C: <ROUTING> Deactivating address %r", address)
        # We use `discard` instead of `remove` here since the former
        # will not fail if the address has already been removed.
        self.routing_table.routers.discard(address)
        self.routing_table.readers.discard(address)
        self.routing_table.writers.discard(address)
        log.debug("[#0000]  C: <ROUTING> table=%r", self.routing_table)
        super(Neo4jPool, self).deactivate(address)

    def on_write_failure(self, address):
        """ Remove a writer address from the routing table, if present.
        """
        log.debug("[#0000]  C: <ROUTING> Removing writer %r", address)
        self.routing_table.writers.discard(address)
        log.debug("[#0000]  C: <ROUTING> table=%r", self.routing_table)
Пример #19
0
 def test_should_be_initially_stale(self):
     """A table built without routing data is stale for every access mode."""
     table = RoutingTable(database=DEFAULT_DATABASE)
     for readonly in (True, False):
         assert not table.is_fresh(readonly=readonly)
 def test_should_become_stale_if_no_writers(self):
     """Dropping all writers should only cost write freshness."""
     servers, ttl = VALID_ROUTING_RECORD["servers"], VALID_ROUTING_RECORD["ttl"]
     table = RoutingTable.parse_routing_info(servers, ttl)
     table.writers.clear()
     fresh_for_read = table.is_fresh(readonly=True)
     assert fresh_for_read
     fresh_for_write = table.is_fresh(readonly=False)
     assert not fresh_for_write
 def test_should_become_stale_on_expiry(self):
     """With a TTL of zero the table counts as stale for every access mode."""
     servers, ttl = VALID_ROUTING_RECORD["servers"], VALID_ROUTING_RECORD["ttl"]
     table = RoutingTable.parse_routing_info(servers, ttl)
     table.ttl = 0
     for readonly in (True, False):
         assert not table.is_fresh(readonly=readonly)
 def test_should_be_fresh_after_update(self):
     """A freshly parsed valid record is fresh for both reads and writes."""
     servers, ttl = VALID_ROUTING_RECORD["servers"], VALID_ROUTING_RECORD["ttl"]
     table = RoutingTable.parse_routing_info(servers, ttl)
     for readonly in (True, False):
         assert table.is_fresh(readonly=readonly)
class Neo4jPool:
    """ Connection pool with routing table.

    One :class:`BoltPool` is kept per server address known to the routing
    table. Read connections come from reader pools and write connections
    from writer pools. The routing table is refreshed from the routers
    whenever it is found to be stale.
    """

    @classmethod
    async def open(cls,
                   *addresses,
                   auth=None,
                   routing_context=None,
                   loop=None,
                   **config):
        """ Create a pool for the given initial routers and fetch a first
        routing table before returning it.

        :param addresses: initial router addresses
        :param auth: authentication details, passed on to ``Bolt.open``
        :param routing_context: routing context sent with each routing
            table request
        :param loop: event loop to use; None means ``get_event_loop()``
        :param config: keyword configuration; pool settings are taken out
            by ``PoolConfig.consume``
        :return: the new, initialised pool
        :raise Neo4jAvailabilityError: if no router can supply routing
            information
        """
        pool_config = PoolConfig.consume(config)

        def opener(addr):
            return Bolt.open(addr, auth=auth, **pool_config)

        # Hand the consumed PoolConfig to the pool rather than the plain
        # `**config` dict. __init__ reads `config.max_size`, which a dict
        # does not have, and forwards the config to every BoltPool.
        obj = cls(loop, opener, pool_config, addresses, routing_context)
        await obj._ensure_routing_table_is_fresh()
        return obj

    def __init__(self, loop, opener, config, addresses, routing_context):
        """ Set up pool state and create a pool for each initial router.

        :param loop: event loop, or None to use ``get_event_loop()``
        :param opener: callable taking an address; passed to each BoltPool
            (presumably used there to open connections)
        :param config: pool configuration; must provide ``max_size``
        :param addresses: initial router addresses
        :param routing_context: routing context for routing table requests
        """
        if loop is None:
            self._loop = get_event_loop()
        else:
            self._loop = loop
        self._opener = opener
        self._config = config
        # address -> BoltPool
        self._pools = {}
        # True when the routing information suggests no writer is currently
        # available (e.g. during a leader switch).
        self._missing_writer = False
        # NOTE(review): asyncio.Lock's `loop` argument was deprecated in
        # Python 3.8 and removed in 3.10; this call fails there.
        self._refresh_lock = Lock(loop=self._loop)
        self._routing_context = routing_context
        self._max_size_per_host = config.max_size
        self._initial_routers = addresses
        self._routing_table = RoutingTable(addresses)
        self._activate_new_pools_in(self._routing_table)

    def _activate_new_pools_in(self, routing_table):
        """ Add pools for addresses that exist in the given routing
        table but which don't already have pools.
        """
        for address in routing_table.servers():
            if address not in self._pools:
                self._pools[address] = BoltPool(self._loop, self._opener,
                                                self._config, address)

    async def _deactivate_pools_not_in(self, routing_table):
        """ Deactivate any pools that aren't represented in the given
        routing table.
        """
        # Iterate over a snapshot: `_deactivate` pops entries from
        # `self._pools`, and mutating a dict while iterating it raises
        # RuntimeError.
        for address in list(self._pools):
            if address not in routing_table:
                await self._deactivate(address)

    async def _get_routing_table_from(self, *routers):
        """ Try each of the given routers in turn until one returns usable
        routing information.

        A router that fails with a BoltError, either while acquiring a
        connection or while fetching the table, is deactivated and the
        next router is tried. A table without routers or without readers
        is skipped.

        :return: the new routing table, or None if no router provided a
            usable one
        """
        log.debug("Attempting to update routing table from "
                  "{}".format(", ".join(map(repr, routers))))
        for router in routers:
            pool = self._pools.get(router)
            if pool is None:
                # No pool for this router, e.g. an initial router whose pool
                # was deactivated by an earlier refresh: open a fresh one.
                pool = BoltPool(self._loop, self._opener, self._config, router)
                self._pools[router] = pool
            try:
                cx = await pool.acquire()
            except BoltError:
                # Unreachable router: drop it and move on to the next one.
                await self._deactivate(router)
                continue
            try:
                new_routing_table = await cx.get_routing_table(
                    self._routing_context)
            except BoltError:
                await self._deactivate(router)
            else:
                num_routers = len(new_routing_table.routers)
                num_readers = len(new_routing_table.readers)
                num_writers = len(new_routing_table.writers)

                # No writers are available. This likely indicates a temporary state,
                # such as leader switching, so we should not signal an error.
                # When no writers available, then we flag we are reading in absence of writer
                self._missing_writer = (num_writers == 0)

                # No routers
                if num_routers == 0:
                    continue

                # No readers
                if num_readers == 0:
                    continue

                log.debug("Successfully updated routing table from "
                          "{!r} ({!r})".format(router, new_routing_table))
                return new_routing_table
            finally:
                await pool.release(cx)
        return None

    async def _get_routing_table(self):
        """ Fetch a routing table from the first router able to provide
        valid routing information.

        If a writer is missing, the initial routers are tried first. Then
        the routers in the current table are tried. Finally, any initial
        routers not yet tried are tried.

        :raise Neo4jAvailabilityError: if no router succeeds
        """
        # copied because it can be modified
        existing_routers = list(self._routing_table.routers)

        has_tried_initial_routers = False
        if self._missing_writer:
            has_tried_initial_routers = True
            rt = await self._get_routing_table_from(*self._initial_routers)
            if rt is not None:
                return rt

        rt = await self._get_routing_table_from(*existing_routers)
        if rt is not None:
            return rt

        if not has_tried_initial_routers:
            # Only retry initial routers that were not just tried above.
            untried_routers = [router for router in self._initial_routers
                               if router not in existing_routers]
            if untried_routers:
                rt = await self._get_routing_table_from(*untried_routers)
                if rt is not None:
                    return rt

        # None of the routers have been successful, so just fail
        log.error("Unable to retrieve routing information")
        raise Neo4jAvailabilityError("Unable to retrieve routing information")

    async def _ensure_routing_table_is_fresh(self, readonly=False):
        """ Update the routing table if stale.

        This method performs two freshness checks, before and after acquiring
        the refresh lock. If the routing table is already fresh on entry, the
        method exits immediately; otherwise, the refresh lock is acquired and
        the second freshness check that follows determines whether an update
        is still required.
        """
        if self._routing_table.is_fresh(readonly=readonly):
            return
        async with self._refresh_lock:
            if self._routing_table.is_fresh(readonly=readonly):
                if readonly:
                    # if reader is fresh but writers are not, then
                    # we are reading in absence of writer
                    self._missing_writer = not self._routing_table.is_fresh(
                        readonly=False)
            else:
                # _get_routing_table either returns a table or raises.
                rt = await self._get_routing_table()
                self._activate_new_pools_in(rt)
                self._routing_table.update(rt)
                await self._deactivate_pools_not_in(rt)

    async def _select_pool(self, readonly=False):
        """ Selects the pool with the fewest in-use connections.

        Ties are broken at random.

        :raise Neo4jAvailabilityError: if no reader (or writer) pool is
            available
        """
        await self._ensure_routing_table_is_fresh(readonly=readonly)
        if readonly:
            addresses = self._routing_table.readers
        else:
            addresses = self._routing_table.writers
        pools = [
            pool for address, pool in self._pools.items()
            if address in addresses
        ]
        pools_by_usage = {}
        for pool in pools:
            pools_by_usage.setdefault(pool.in_use, []).append(pool)
        if not pools_by_usage:
            raise Neo4jAvailabilityError(
                "No {} service currently "
                "available".format("read" if readonly else "write"))
        return choice(pools_by_usage[min(pools_by_usage)])

    async def acquire(self, *, readonly=False, force_reset=False):
        """ Acquire a connection to a server that can satisfy a set of parameters.

        Servers whose pools fail with a BoltError are deactivated and
        another server is selected.

        :param readonly: true if a readonly connection is required,
            otherwise false
        :param force_reset: passed through to ``BoltPool.acquire``
        """
        while True:
            pool = await self._select_pool(readonly=readonly)
            try:
                cx = await pool.acquire(force_reset=force_reset)
            except BoltError:
                await self._deactivate(pool.address)
            else:
                if not readonly:
                    # If we're not acquiring a connection as
                    # readonly, then intercept NotALeader and
                    # ForbiddenOnReadOnlyDatabase errors to
                    # invalidate the routing table.
                    from neo4j.errors import (
                        NotALeader,
                        ForbiddenOnReadOnlyDatabase,
                    )

                    def handler(failure):
                        """ Invalidate the routing table before raising the failure.
                        """
                        log.debug(
                            "[#0000]  C: <ROUTING> Invalidating routing table")
                        self._routing_table.ttl = 0
                        raise failure

                    cx.set_failure_handler(NotALeader, handler)
                    cx.set_failure_handler(ForbiddenOnReadOnlyDatabase,
                                           handler)
                return cx

    async def release(self, connection, *, force_reset=False):
        """ Release a connection back into the pool it belongs to.

        Any routing failure handlers installed by :meth:`acquire` are
        removed from the connection.

        :raise ValueError: if no pool owns the connection
        """
        for pool in self._pools.values():
            try:
                await pool.release(connection, force_reset=force_reset)
            except ValueError:
                pass
            else:
                # Unhook any custom error handling and exit.
                from neo4j.errors import (
                    NotALeader,
                    ForbiddenOnReadOnlyDatabase,
                )
                connection.del_failure_handler(NotALeader)
                connection.del_failure_handler(ForbiddenOnReadOnlyDatabase)
                break
        else:
            raise ValueError("Connection does not belong to this pool")

    async def _deactivate(self, address):
        """ Deactivate an address from the connection pool,
        if present, remove from the routing table and also closing
        all idle connections to that address.
        """
        log.debug("[#0000]  C: <ROUTING> Deactivating address %r", address)
        # We use `discard` instead of `remove` here since the former
        # will not fail if the address has already been removed.
        self._routing_table.routers.discard(address)
        self._routing_table.readers.discard(address)
        self._routing_table.writers.discard(address)
        log.debug("[#0000]  C: <ROUTING> table=%r", self._routing_table)
        try:
            pool = self._pools.pop(address)
        except KeyError:
            pass  # assume the address has already been removed
        else:
            pool.max_size = 0
            await pool.prune()

    async def close(self, force=False):
        """ Close all connections and empty the pool. If forced, in-use
        connections will be closed immediately; if not, they will
        remain open until released.
        """
        pools = dict(self._pools)
        self._pools.clear()
        for pool in pools.values():
            if force:
                await pool.close()
            else:
                pool.max_size = 0
                await pool.prune()
Пример #24
0
 def create_routing_table(self, database):
     """ Register a routing table for ``database``, seeded with the default
     database's initial router addresses, unless one is already registered.
     """
     if database in self.routing_tables:
         return
     initial_routers = self.get_default_database_initial_router_addresses()
     self.routing_tables[database] = RoutingTable(
         database=database, routers=initial_routers)
 def test_should_be_initially_stale(self):
     """ A routing table built with no arguments starts out stale. """
     blank_table = RoutingTable()
     for readonly in (True, False):
         assert not blank_table.is_fresh(readonly=readonly)