class RoutingTableUpdateTestCase(TestCase):
    """ Check that ``RoutingTable.update`` takes over every field of the
    table it is given.
    """

    def setUp(self):
        # Starting table: two routers, one reader, no writers and ttl=0.
        old_routers = [("192.168.1.1", 7687), ("192.168.1.2", 7687)]
        old_readers = [("192.168.1.3", 7687)]
        self.table = RoutingTable(
            database=DEFAULT_DATABASE,
            routers=old_routers,
            readers=old_readers,
            writers=[],
            ttl=0,
        )

        # Replacement table: every field differs from the starting table.
        fresh_routers = [("127.0.0.1", 9001), ("127.0.0.1", 9002), ("127.0.0.1", 9003)]
        fresh_readers = [("127.0.0.1", 9004), ("127.0.0.1", 9005)]
        fresh_writers = [("127.0.0.1", 9006)]
        self.new_table = RoutingTable(
            database=DEFAULT_DATABASE,
            routers=fresh_routers,
            readers=fresh_readers,
            writers=fresh_writers,
            ttl=300,
        )

    def test_update_should_replace_routers(self):
        expected = {("127.0.0.1", 9001), ("127.0.0.1", 9002), ("127.0.0.1", 9003)}
        self.table.update(self.new_table)
        assert self.table.routers == expected

    def test_update_should_replace_readers(self):
        expected = {("127.0.0.1", 9004), ("127.0.0.1", 9005)}
        self.table.update(self.new_table)
        assert self.table.readers == expected

    def test_update_should_replace_writers(self):
        expected = {("127.0.0.1", 9006)}
        self.table.update(self.new_table)
        assert self.table.writers == expected

    def test_update_should_replace_ttl(self):
        expected = 300
        self.table.update(self.new_table)
        assert self.table.ttl == expected
class RoutingTableUpdateTestCase(TestCase):
    """ Check that ``RoutingTable.update`` takes over every field of the
    table it is given (positional ``RoutingTable`` constructor).
    """

    def setUp(self):
        # Positional order used below: routers, readers, writers, ttl.
        old_routers = [("192.168.1.1", 7687), ("192.168.1.2", 7687)]
        old_readers = [("192.168.1.3", 7687)]
        self.table = RoutingTable(old_routers, old_readers, [], 0)

        fresh_routers = [("127.0.0.1", 9001), ("127.0.0.1", 9002), ("127.0.0.1", 9003)]
        fresh_readers = [("127.0.0.1", 9004), ("127.0.0.1", 9005)]
        fresh_writers = [("127.0.0.1", 9006)]
        self.new_table = RoutingTable(fresh_routers, fresh_readers, fresh_writers, 300)

    def test_update_should_replace_routers(self):
        expected = {("127.0.0.1", 9001), ("127.0.0.1", 9002), ("127.0.0.1", 9003)}
        self.table.update(self.new_table)
        assert self.table.routers == expected

    def test_update_should_replace_readers(self):
        expected = {("127.0.0.1", 9004), ("127.0.0.1", 9005)}
        self.table.update(self.new_table)
        assert self.table.readers == expected

    def test_update_should_replace_writers(self):
        expected = {("127.0.0.1", 9006)}
        self.table.update(self.new_table)
        assert self.table.writers == expected

    def test_update_should_replace_ttl(self):
        expected = 300
        self.table.update(self.new_table)
        assert self.table.ttl == expected
class Neo4jPool(IOPool):
    """ Connection pool with routing table.
    """

    @classmethod
    def open(cls, *addresses, auth=None, routing_context=None, **config):
        """ Create a pool for the given initial router addresses and fetch
        its first routing table.

        :param addresses: initial router addresses
        :param auth: authentication details passed through to ``Bolt.open``
        :param routing_context: context passed to the routing procedure
        :param config: settings consumed by ``PoolConfig.consume``
        :return: a pool whose routing table has been updated once
        :raise: whatever ``update_routing_table`` raises; the pool is
            closed before the exception propagates
        """
        pool_config = PoolConfig.consume(config)

        def opener(addr, timeout):
            return Bolt.open(addr, auth=auth, timeout=timeout, **pool_config)

        pool = cls(opener, pool_config, addresses, routing_context)
        try:
            pool.update_routing_table()
        except Exception:
            # Do not leave a half-initialised pool open behind us.
            pool.close()
            raise
        else:
            return pool

    def __init__(self, opener, pool_config, addresses, routing_context):
        # Zero-argument super() does not look up the global name
        # ``Neo4jPool``, so it keeps working if that name is ever rebound.
        super().__init__(opener, pool_config)
        self.routing_table = RoutingTable(addresses)
        self.routing_context = routing_context
        # True while the last routing information seen had no writers.
        self.missing_writer = False
        self.refresh_lock = Lock()

    def __repr__(self):
        return "<{} addresses={!r}>".format(self.__class__.__name__,
                                            self.routing_table.initial_routers)

    @property
    def initial_address(self):
        """ The first of the routing table's initial routers. """
        return self.routing_table.initial_routers[0]

    def fetch_routing_info(self, address, timeout=None):
        """ Fetch raw routing info from a given router address.

        :param address: router address
        :param timeout: seconds
        :return: list of routing records (one dict per record) or None
            if no connection could be established
        :raise ServiceUnavailable: if the server does not support routing
            or if routing support is broken
        """
        metadata = {}
        records = []

        def fail(md):
            # Called on failure of the routing procedure call.
            if md.get("code") == "Neo.ClientError.Procedure.ProcedureNotFound":
                raise BoltRoutingError("Server does not support routing", address)
            else:
                raise BoltRoutingError("Routing support broken on server", address)

        try:
            with self._acquire(address, timeout) as cx:
                log.debug("[#%04X] C: <ROUTING> query=%r", cx.local_port, self.routing_context or {})
                cx.run("CALL dbms.cluster.routing.getRoutingTable($context)",
                       {"context": self.routing_context},
                       on_success=metadata.update, on_failure=fail)
                cx.pull(on_success=metadata.update, on_records=records.extend)
                cx.send_all()
                cx.fetch_all()
                # Pair each record's values with the field names reported
                # in the success metadata.
                routing_info = [dict(zip(metadata.get("fields", ()), values)) for values in records]
                log.debug("[#%04X] S: <ROUTING> info=%r", cx.local_port, routing_info)
            return routing_info
        except BoltRoutingError as error:
            raise ServiceUnavailable(*error.args) from error
        except ServiceUnavailable:
            # The router could not be reached: drop it and report "no info".
            self.deactivate(address)
            return None

    def fetch_routing_table(self, address, timeout=None):
        """ Fetch a routing table from a given router address.

        As a side effect, ``self.missing_writer`` is set to reflect whether
        the received table contains any writers. A table without writers
        is not treated as an error.

        :param address: router address
        :param timeout: seconds
        :return: a new RoutingTable instance or None if the given router is
            currently unable to provide routing information
        :raise BoltRoutingError: if the routing information received is
            empty, or contains no routers or no readers
        """
        new_routing_info = self.fetch_routing_info(address, timeout)
        if new_routing_info is None:
            return None
        elif not new_routing_info:
            raise BoltRoutingError("Invalid routing table", address)
        else:
            servers = new_routing_info[0]["servers"]
            ttl = new_routing_info[0]["ttl"]
            new_routing_table = RoutingTable.parse_routing_info(servers, ttl)

        # Parse routing info and count the number of each type of server
        num_routers = len(new_routing_table.routers)
        num_readers = len(new_routing_table.readers)
        num_writers = len(new_routing_table.writers)

        # No writers are available. This likely indicates a temporary state,
        # such as leader switching, so we should not signal an error.
        # When no writers available, then we flag we are reading in absence of writer
        self.missing_writer = (num_writers == 0)

        # No routers
        if num_routers == 0:
            raise BoltRoutingError("No routing servers returned from server", address)

        # No readers
        if num_readers == 0:
            raise BoltRoutingError("No read servers returned from server", address)

        # At least one of each is fine, so return this table
        return new_routing_table

    def update_routing_table_from(self, *routers):
        """ Try to update routing tables with the given routers.

        Routers are tried in order; the first one that yields a routing
        table wins.

        :return: True if the routing table is successfully updated,
            otherwise False
        """
        log.debug("Attempting to update routing table from "
                  "{}".format(", ".join(map(repr, routers))))
        for router in routers:
            new_routing_table = self.fetch_routing_table(router)
            if new_routing_table is not None:
                self.routing_table.update(new_routing_table)
                log.debug("Successfully updated routing table from "
                          "{!r} ({!r})".format(router, self.routing_table))
                return True
        return False

    def update_routing_table(self):
        """ Update the routing table from the first router able to provide
        valid routing information.

        :raise ServiceUnavailable: if none of the routers could provide
            routing information
        """
        # copied because it can be modified
        existing_routers = list(self.routing_table.routers)

        has_tried_initial_routers = False
        if self.missing_writer:
            # While no writer is known, ask the initial router first.
            has_tried_initial_routers = True
            if self.update_routing_table_from(self.initial_address):
                return

        if self.update_routing_table_from(*existing_routers):
            return

        if not has_tried_initial_routers and self.initial_address not in existing_routers:
            if self.update_routing_table_from(self.initial_address):
                return

        # None of the routers have been successful, so just fail
        log.error("Unable to retrieve routing information")
        raise ServiceUnavailable("Unable to retrieve routing information")

    def update_connection_pool(self):
        """ Deactivate pooled addresses that are no longer listed in the
        routing table.
        """
        servers = self.routing_table.servers()
        # Iterate over a copy: deactivation may modify ``self.connections``.
        for address in list(self.connections):
            if address not in servers:
                # Call the parent implementation directly; the address is
                # already absent from the routing table.
                super().deactivate(address)

    def ensure_routing_table_is_fresh(self, access_mode):
        """ Update the routing table if stale.

        This method performs two freshness checks, before and after acquiring
        the refresh lock. If the routing table is already fresh on entry, the
        method exits immediately; otherwise, the refresh lock is acquired and
        the second freshness check that follows determines whether an update
        is still required.

        Updates are serialised through ``self.refresh_lock``.

        :param access_mode: requested access mode; compared against
            ``READ_ACCESS`` to decide which freshness check applies
        :return: `True` if an update was required, `False` otherwise.
        """
        from neo4j.api import READ_ACCESS
        if self.routing_table.is_fresh(readonly=(access_mode == READ_ACCESS)):
            return False
        with self.refresh_lock:
            if self.routing_table.is_fresh(readonly=(access_mode == READ_ACCESS)):
                if access_mode == READ_ACCESS:
                    # if reader is fresh but writers is not fresh, then we are reading in absence of writer
                    self.missing_writer = not self.routing_table.is_fresh(readonly=False)
                return False
            self.update_routing_table()
            self.update_connection_pool()
            return True

    def _select_address(self, access_mode=None):
        """ Selects the address with the fewest in-use connections.

        :param access_mode: readers are used for ``READ_ACCESS``, writers
            for anything else
        :return: one address, chosen at random among the least-used ones
        :raise ReadServiceUnavailable: if reading and no reader is known
        :raise WriteServiceUnavailable: if writing and no writer is known
        """
        from neo4j.api import READ_ACCESS
        self.ensure_routing_table_is_fresh(access_mode)
        if access_mode == READ_ACCESS:
            addresses = self.routing_table.readers
        else:
            addresses = self.routing_table.writers
        # Group candidate addresses by their current in-use connection count.
        addresses_by_usage = {}
        for address in addresses:
            addresses_by_usage.setdefault(self.in_use_connection_count(address), []).append(address)
        if not addresses_by_usage:
            if access_mode == READ_ACCESS:
                raise ReadServiceUnavailable("No read service currently available")
            else:
                raise WriteServiceUnavailable("No write service currently available")
        return choice(addresses_by_usage[min(addresses_by_usage)])

    def acquire(self, access_mode=None, timeout=None):
        """ Acquire a connection to a server suitable for the access mode.

        Addresses that turn out to be unavailable are deactivated and
        another address is selected.

        :param access_mode: validated through ``check_access_mode``
        :param timeout: passed through to ``_acquire``
        :return: a connection
        :raise SessionExpired: if no suitable server is available
        """
        from neo4j.api import check_access_mode
        access_mode = check_access_mode(access_mode)
        while True:
            try:
                address = self._select_address(access_mode)
            except (ReadServiceUnavailable, WriteServiceUnavailable) as err:
                raise SessionExpired("Failed to obtain connection towards '%s' server." % access_mode) from err
            try:
                connection = self._acquire(address, timeout=timeout)  # should always be a resolved address
            except ServiceUnavailable:
                self.deactivate(address)
            else:
                return connection

    def deactivate(self, address):
        """ Deactivate an address from the connection pool,
        if present, remove from the routing table and also closing
        all idle connections to that address.
        """
        log.debug("[#0000] C: <ROUTING> Deactivating address %r", address)
        # We use `discard` instead of `remove` here since the former
        # will not fail if the address has already been removed.
        self.routing_table.routers.discard(address)
        self.routing_table.readers.discard(address)
        self.routing_table.writers.discard(address)
        log.debug("[#0000] C: <ROUTING> table=%r", self.routing_table)
        super().deactivate(address)

    def on_write_failure(self, address):
        """ Remove a writer address from the routing table, if present.
        """
        log.debug("[#0000] C: <ROUTING> Removing writer %r", address)
        self.routing_table.writers.discard(address)
        log.debug("[#0000] C: <ROUTING> table=%r", self.routing_table)
class Neo4jPool:
    """ Connection pool with routing table.

    Keeps one ``BoltPool`` per server address and a routing table that
    decides which of those pools may serve read or write requests.
    """

    @classmethod
    async def open(cls, *addresses, auth=None, routing_context=None, loop=None, **config):
        """ Create a pool for the given initial router addresses and make
        sure its routing table is fresh.

        :param addresses: initial router addresses
        :param auth: authentication details passed through to ``Bolt.open``
        :param routing_context: context passed when requesting routing tables
        :param loop: event loop to use; ``get_event_loop()`` if None
        :param config: settings consumed by ``PoolConfig.consume``
        :return: a new pool instance
        """
        pool_config = PoolConfig.consume(config)

        def opener(addr):
            return Bolt.open(addr, auth=auth, **pool_config)

        # Pass the consumed PoolConfig, not the raw keyword dict: __init__
        # reads ``config.max_size``, which a plain dict does not provide.
        obj = cls(loop, opener, pool_config, addresses, routing_context)
        await obj._ensure_routing_table_is_fresh()
        return obj

    def __init__(self, loop, opener, config, addresses, routing_context):
        if loop is None:
            self._loop = get_event_loop()
        else:
            self._loop = loop
        self._opener = opener
        self._config = config
        # Maps server address -> BoltPool.
        self._pools = {}
        # True while the last routing table seen had no writers.
        self._missing_writer = False
        # NOTE(review): if ``Lock`` is ``asyncio.Lock``, the ``loop``
        # argument was removed in Python 3.10 - confirm supported versions.
        self._refresh_lock = Lock(loop=self._loop)
        self._routing_context = routing_context
        self._max_size_per_host = config.max_size
        self._initial_routers = addresses
        self._routing_table = RoutingTable(addresses)
        self._activate_new_pools_in(self._routing_table)

    def _get_or_create_pool(self, address):
        """ Return the pool for an address, creating it if necessary. """
        try:
            return self._pools[address]
        except KeyError:
            pool = BoltPool(self._loop, self._opener, self._config, address)
            self._pools[address] = pool
            return pool

    def _activate_new_pools_in(self, routing_table):
        """ Add pools for addresses that exist in the given routing
        table but which don't already have pools.
        """
        for address in routing_table.servers():
            self._get_or_create_pool(address)

    async def _deactivate_pools_not_in(self, routing_table):
        """ Deactivate any pools that aren't represented in the given
        routing table.
        """
        # Iterate over a snapshot: _deactivate removes entries from
        # self._pools, which would otherwise break the iteration.
        for address in list(self._pools):
            if address not in routing_table:
                await self._deactivate(address)

    async def _get_routing_table_from(self, *routers):
        """ Try to fetch a routing table from the given routers, in order.

        Routers that cannot be reached, or that fail to deliver a table,
        are deactivated. Tables without routers or without readers are
        skipped. ``self._missing_writer`` is updated from every table
        received.

        :return: the first usable routing table, or None if no router
            provided one
        """
        log.debug("Attempting to update routing table from "
                  "{}".format(", ".join(map(repr, routers))))
        for router in routers:
            # The pool of an initial router may have been deactivated
            # earlier, so create it on demand instead of failing.
            pool = self._get_or_create_pool(router)
            try:
                cx = await pool.acquire()
            except BoltError:
                # An unreachable router must not stop us trying the others.
                await self._deactivate(router)
                continue
            try:
                new_routing_table = await cx.get_routing_table(
                    self._routing_context)
            except BoltError:
                await self._deactivate(router)
            else:
                num_routers = len(new_routing_table.routers)
                num_readers = len(new_routing_table.readers)
                num_writers = len(new_routing_table.writers)

                # No writers are available. This likely indicates a temporary state,
                # such as leader switching, so we should not signal an error.
                # When no writers available, then we flag we are reading in absence of writer
                self._missing_writer = (num_writers == 0)

                # No routers
                if num_routers == 0:
                    continue

                # No readers
                if num_readers == 0:
                    continue

                log.debug("Successfully updated routing table from "
                          "{!r} ({!r})".format(router, new_routing_table))
                return new_routing_table
            finally:
                # Runs on every path, including ``continue`` and ``return``.
                await pool.release(cx)
        return None

    async def _get_routing_table(self):
        """ Fetch a routing table from the first router able to provide
        valid routing information.

        :return: the new routing table
        :raise Neo4jAvailabilityError: if no router provided a usable table
        """
        # copied because it can be modified
        existing_routers = list(self._routing_table.routers)

        has_tried_initial_routers = False
        if self._missing_writer:
            # While no writer is known, ask the initial routers first.
            has_tried_initial_routers = True
            rt = await self._get_routing_table_from(*self._initial_routers)
            if rt:
                return rt

        rt = await self._get_routing_table_from(*existing_routers)
        if rt:
            return rt

        # Fall back to the initial routers only if that adds an address
        # which has not just been tried.
        has_untried_initial_router = any(
            router not in existing_routers
            for router in self._initial_routers
        )
        if not has_tried_initial_routers and has_untried_initial_router:
            rt = await self._get_routing_table_from(*self._initial_routers)
            if rt:
                return rt

        # None of the routers have been successful, so just fail
        log.error("Unable to retrieve routing information")
        raise Neo4jAvailabilityError("Unable to retrieve routing information")

    async def _ensure_routing_table_is_fresh(self, readonly=False):
        """ Update the routing table if stale.

        This method performs two freshness checks, before and after acquiring
        the refresh lock. If the routing table is already fresh on entry, the
        method exits immediately; otherwise, the refresh lock is acquired and
        the second freshness check that follows determines whether an update
        is still required.

        :param readonly: passed to the routing table's freshness check
        """
        if self._routing_table.is_fresh(readonly=readonly):
            return
        async with self._refresh_lock:
            if self._routing_table.is_fresh(readonly=readonly):
                if readonly:
                    # if reader is fresh but writers are not, then
                    # we are reading in absence of writer
                    self._missing_writer = not self._routing_table.is_fresh(
                        readonly=False)
            else:
                rt = await self._get_routing_table()
                self._activate_new_pools_in(rt)
                self._routing_table.update(rt)
                await self._deactivate_pools_not_in(rt)

    async def _select_pool(self, readonly=False):
        """ Selects the pool with the fewest in-use connections.

        :param readonly: choose among readers if true, otherwise writers
        :return: one pool, chosen at random among the least-used ones
        :raise Neo4jAvailabilityError: if no matching pool exists
        """
        await self._ensure_routing_table_is_fresh(readonly=readonly)
        if readonly:
            addresses = self._routing_table.readers
        else:
            addresses = self._routing_table.writers
        pools = [
            pool for address, pool in self._pools.items()
            if address in addresses
        ]
        # Group candidate pools by their current in-use count.
        pools_by_usage = {}
        for pool in pools:
            pools_by_usage.setdefault(pool.in_use, []).append(pool)
        if not pools_by_usage:
            raise Neo4jAvailabilityError(
                "No {} service currently "
                "available".format("read" if readonly else "write"))
        return choice(pools_by_usage[min(pools_by_usage)])

    async def acquire(self, *, readonly=False, force_reset=False):
        """ Acquire a connection to a server that can satisfy a set of parameters.

        Pools whose acquisition fails with ``BoltError`` are deactivated
        and another pool is selected.

        :param readonly: true if a readonly connection is required,
            otherwise false
        :param force_reset: passed through to the selected pool's
            ``acquire``
        :return: a connection
        """
        while True:
            pool = await self._select_pool(readonly=readonly)
            try:
                cx = await pool.acquire(force_reset=force_reset)
            except BoltError:
                await self._deactivate(pool.address)
            else:
                if not readonly:
                    # If we're not acquiring a connection as
                    # readonly, then intercept NotALeader and
                    # ForbiddenOnReadOnlyDatabase errors to
                    # invalidate the routing table.
                    from neo4j.errors import (
                        NotALeader,
                        ForbiddenOnReadOnlyDatabase,
                    )

                    def handler(failure):
                        """ Invalidate the routing table before raising the failure.
                        """
                        log.debug(
                            "[#0000] C: <ROUTING> Invalidating routing table")
                        self._routing_table.ttl = 0
                        raise failure

                    cx.set_failure_handler(NotALeader, handler)
                    cx.set_failure_handler(ForbiddenOnReadOnlyDatabase, handler)
                return cx

    async def release(self, connection, *, force_reset=False):
        """ Release a connection back into the pool.

        Each per-address pool is offered the connection in turn; a pool
        that does not own it signals this with ``ValueError``.

        :param connection: the connection to release
        :param force_reset: passed through to the owning pool's ``release``
        :raise ValueError: if no pool accepts the connection
        """
        for pool in self._pools.values():
            try:
                await pool.release(connection, force_reset=force_reset)
            except ValueError:
                pass
            else:
                # Unhook any custom error handling and exit.
                from neo4j.errors import (
                    NotALeader,
                    ForbiddenOnReadOnlyDatabase,
                )
                connection.del_failure_handler(NotALeader)
                connection.del_failure_handler(ForbiddenOnReadOnlyDatabase)
                break
        else:
            raise ValueError("Connection does not belong to this pool")

    async def _deactivate(self, address):
        """ Deactivate an address from the connection pool,
        if present, remove from the routing table and also closing
        all idle connections to that address.
        """
        log.debug("[#0000] C: <ROUTING> Deactivating address %r", address)
        # We use `discard` instead of `remove` here since the former
        # will not fail if the address has already been removed.
        self._routing_table.routers.discard(address)
        self._routing_table.readers.discard(address)
        self._routing_table.writers.discard(address)
        log.debug("[#0000] C: <ROUTING> table=%r", self._routing_table)
        try:
            pool = self._pools.pop(address)
        except KeyError:
            pass  # assume the address has already been removed
        else:
            pool.max_size = 0
            await pool.prune()

    async def close(self, force=False):
        """ Close all connections and empty the pool. If forced, in-use
        connections will be closed immediately; if not, they will remain
        open until released.
        """
        # Detach the pools first so that none of them can be selected
        # while they are being shut down.
        pools = dict(self._pools)
        self._pools.clear()
        for address, pool in pools.items():
            if force:
                await pool.close()
            else:
                pool.max_size = 0
                await pool.prune()