class KazooTreeCacheTests(KazooAdaptiveHandlerTestCase):
    def setUp(self):
        super(KazooTreeCacheTests, self).setUp()
        self._event_queue = self.client.handler.queue_impl()
        self._error_queue = self.client.handler.queue_impl()
        self.path = None
        self.cache = None

    def tearDown(self):
        if not self._error_queue.empty():
            try:
                raise self._error_queue.get()
            except FakeException:
                pass
        if self.cache is not None:
            self.cache.close()
            self.cache = None
        super(KazooTreeCacheTests, self).tearDown()

    def make_cache(self):
        if self.cache is None:
            self.path = '/' + uuid.uuid4().hex
            self.cache = TreeCache(self.client, self.path)
            self.cache.listen(lambda event: self._event_queue.put(event))
            self.cache.listen_fault(lambda error: self._error_queue.put(error))
            self.cache.start()
        return self.cache

    def wait_cache(self, expect=None, since=None, timeout=10):
        started = since is None
        while True:
            event = self._event_queue.get(timeout=timeout)
            if started:
                if expect is not None:
                    eq_(event.event_type, expect)
                return event
            if event.event_type == since:
                started = True
                if expect is None:
                    return

    def spy_client(self, method_name):
        method = getattr(self.client, method_name)
        return patch.object(self.client, method_name, wraps=method)

    def _wait_gc(self):
        # trigger switching on some coroutine handlers
        self.client.handler.sleep_func(0.1)

        completion_queue = getattr(self.handler, 'completion_queue', None)
        if completion_queue is not None:
            while not self.client.handler.completion_queue.empty():
                self.client.handler.sleep_func(0.1)

        for gen in range(3):
            gc.collect(gen)

    def count_tree_node(self):
        # inspect GC and count tree nodes for checking memory leak
        for retry in range(10):
            result = set()
            for _ in range(5):
                self._wait_gc()
                result.add(count_refs_by_type('TreeNode'))
            if len(result) == 1:
                return list(result)[0]
        raise RuntimeError('could not count refs exactly')

    def test_start(self):
        self.make_cache()
        self.wait_cache(since=TreeEvent.INITIALIZED)

        stat = self.client.exists(self.path)
        eq_(stat.version, 0)

        eq_(self.cache._state, TreeCache.STATE_STARTED)
        eq_(self.cache._root._state, TreeNode.STATE_LIVE)

    @raises(KazooException)
    def test_start_started(self):
        self.make_cache()
        self.cache.start()

    @raises(KazooException)
    def test_start_closed(self):
        self.make_cache()
        self.cache.start()
        self.cache.close()
        self.cache.start()

    def test_close(self):
        eq_(self.count_tree_node(), 0)

        self.make_cache()
        self.wait_cache(since=TreeEvent.INITIALIZED)
        self.client.create(self.path + '/foo/bar/baz', makepath=True)
        for _ in range(3):
            self.wait_cache(TreeEvent.NODE_ADDED)

        # set up stub watchers which are outside of the tree cache
        stub_data_watcher = Mock(spec=lambda event: None)
        stub_child_watcher = Mock(spec=lambda event: None)
        self.client.get(self.path + '/foo', stub_data_watcher)
        self.client.get_children(self.path + '/foo', stub_child_watcher)

        # watchers inside the tree cache should be here
        root_path = self.client.chroot + self.path
        eq_(len(self.client._data_watchers[root_path + '/foo']), 2)
        eq_(len(self.client._data_watchers[root_path + '/foo/bar']), 1)
        eq_(len(self.client._data_watchers[root_path + '/foo/bar/baz']), 1)
        eq_(len(self.client._child_watchers[root_path + '/foo']), 2)
        eq_(len(self.client._child_watchers[root_path + '/foo/bar']), 1)
        eq_(len(self.client._child_watchers[root_path + '/foo/bar/baz']), 1)

        self.cache.close()

        # nothing should be published since the tree is closed
        ok_(self._event_queue.empty())

        # tree should be empty
        eq_(self.cache._root._children, {})
        eq_(self.cache._root._data, None)
        eq_(self.cache._state, TreeCache.STATE_CLOSED)

        # node state should not be changed
        assert_not_equal(self.cache._root._state, TreeNode.STATE_DEAD)

        # watchers should be reset
        eq_(len(self.client._data_watchers[root_path + '/foo']), 1)
        eq_(len(self.client._data_watchers[root_path + '/foo/bar']), 0)
        eq_(len(self.client._data_watchers[root_path + '/foo/bar/baz']), 0)
        eq_(len(self.client._child_watchers[root_path + '/foo']), 1)
        eq_(len(self.client._child_watchers[root_path + '/foo/bar']), 0)
        eq_(len(self.client._child_watchers[root_path + '/foo/bar/baz']), 0)

        # outside watchers should not be deleted
        eq_(
            list(self.client._data_watchers[root_path + '/foo'])[0],
            stub_data_watcher)
        eq_(
            list(self.client._child_watchers[root_path + '/foo'])[0],
            stub_child_watcher)

        # there should not be any leaked memory (tree nodes) here
        self.cache = None
        eq_(self.count_tree_node(), 0)

    def test_delete_operation(self):
        self.make_cache()
        self.wait_cache(since=TreeEvent.INITIALIZED)
        eq_(self.count_tree_node(), 1)

        self.client.create(self.path + '/foo/bar/baz', makepath=True)
        for _ in range(3):
            self.wait_cache(TreeEvent.NODE_ADDED)

        self.client.delete(self.path + '/foo', recursive=True)
        for _ in range(3):
            self.wait_cache(TreeEvent.NODE_REMOVED)

        # tree should be empty
        eq_(self.cache._root._children, {})

        # watchers should be reset
        root_path = self.client.chroot + self.path
        eq_(self.client._data_watchers[root_path + '/foo'], set())
        eq_(self.client._data_watchers[root_path + '/foo/bar'], set())
        eq_(self.client._data_watchers[root_path + '/foo/bar/baz'], set())
        eq_(self.client._child_watchers[root_path + '/foo'], set())
        eq_(self.client._child_watchers[root_path + '/foo/bar'], set())
        eq_(self.client._child_watchers[root_path + '/foo/bar/baz'], set())

        # there should not be any leaked memory (tree nodes) here
        eq_(self.count_tree_node(), 1)

    def test_children_operation(self):
        self.make_cache()
        self.wait_cache(since=TreeEvent.INITIALIZED)

        self.client.create(self.path + '/test_children', b'test_children_1')
        event = self.wait_cache(TreeEvent.NODE_ADDED)
        eq_(event.event_type, TreeEvent.NODE_ADDED)
        eq_(event.event_data.path, self.path + '/test_children')
        eq_(event.event_data.data, b'test_children_1')
        eq_(event.event_data.stat.version, 0)

        self.client.set(self.path + '/test_children', b'test_children_2')
        event = self.wait_cache(TreeEvent.NODE_UPDATED)
        eq_(event.event_type, TreeEvent.NODE_UPDATED)
        eq_(event.event_data.path, self.path + '/test_children')
        eq_(event.event_data.data, b'test_children_2')
        eq_(event.event_data.stat.version, 1)

        self.client.delete(self.path + '/test_children')
        event = self.wait_cache(TreeEvent.NODE_REMOVED)
        eq_(event.event_type, TreeEvent.NODE_REMOVED)
        eq_(event.event_data.path, self.path + '/test_children')
        eq_(event.event_data.data, b'test_children_2')
        eq_(event.event_data.stat.version, 1)

    def test_subtree_operation(self):
        self.make_cache()
        self.wait_cache(since=TreeEvent.INITIALIZED)

        self.client.create(self.path + '/foo/bar/baz', makepath=True)
        for relative_path in ('/foo', '/foo/bar', '/foo/bar/baz'):
            event = self.wait_cache(TreeEvent.NODE_ADDED)
            eq_(event.event_type, TreeEvent.NODE_ADDED)
            eq_(event.event_data.path, self.path + relative_path)
            eq_(event.event_data.data, b'')
            eq_(event.event_data.stat.version, 0)

        self.client.delete(self.path + '/foo', recursive=True)
        for relative_path in ('/foo/bar/baz', '/foo/bar', '/foo'):
            event = self.wait_cache(TreeEvent.NODE_REMOVED)
            eq_(event.event_type, TreeEvent.NODE_REMOVED)
            eq_(event.event_data.path, self.path + relative_path)

    def test_get_data(self):
        cache = self.make_cache()
        self.wait_cache(since=TreeEvent.INITIALIZED)
        self.client.create(self.path + '/foo/bar/baz', b'@', makepath=True)
        self.wait_cache(TreeEvent.NODE_ADDED)
        self.wait_cache(TreeEvent.NODE_ADDED)
        self.wait_cache(TreeEvent.NODE_ADDED)

        with patch.object(cache, '_client'):  # disable any remote operation
            eq_(cache.get_data(self.path).data, b'')
            eq_(cache.get_data(self.path).stat.version, 0)

            eq_(cache.get_data(self.path + '/foo').data, b'')
            eq_(cache.get_data(self.path + '/foo').stat.version, 0)

            eq_(cache.get_data(self.path + '/foo/bar').data, b'')
            eq_(cache.get_data(self.path + '/foo/bar').stat.version, 0)

            eq_(cache.get_data(self.path + '/foo/bar/baz').data, b'@')
            eq_(cache.get_data(self.path + '/foo/bar/baz').stat.version, 0)

    def test_get_children(self):
        cache = self.make_cache()
        self.wait_cache(since=TreeEvent.INITIALIZED)
        self.client.create(self.path + '/foo/bar/baz', b'@', makepath=True)
        self.wait_cache(TreeEvent.NODE_ADDED)
        self.wait_cache(TreeEvent.NODE_ADDED)
        self.wait_cache(TreeEvent.NODE_ADDED)

        with patch.object(cache, '_client'):  # disable any remote operation
            eq_(cache.get_children(self.path + '/foo/bar/baz'), frozenset())
            eq_(cache.get_children(self.path + '/foo/bar'), frozenset(['baz']))
            eq_(cache.get_children(self.path + '/foo'), frozenset(['bar']))
            eq_(cache.get_children(self.path), frozenset(['foo']))

    @raises(ValueError)
    def test_get_data_out_of_tree(self):
        self.make_cache()
        self.wait_cache(since=TreeEvent.INITIALIZED)
        self.cache.get_data('/out_of_tree')

    @raises(ValueError)
    def test_get_children_out_of_tree(self):
        self.make_cache()
        self.wait_cache(since=TreeEvent.INITIALIZED)
        self.cache.get_children('/out_of_tree')

    def test_get_data_no_node(self):
        cache = self.make_cache()
        self.wait_cache(since=TreeEvent.INITIALIZED)

        with patch.object(cache, '_client'):  # disable any remote operation
            eq_(cache.get_data(self.path + '/non_exists'), None)

    def test_get_children_no_node(self):
        cache = self.make_cache()
        self.wait_cache(since=TreeEvent.INITIALIZED)

        with patch.object(cache, '_client'):  # disable any remote operation
            eq_(cache.get_children(self.path + '/non_exists'), None)

    def test_session_reconnected(self):
        self.make_cache()
        self.wait_cache(since=TreeEvent.INITIALIZED)

        self.client.create(self.path + '/foo')
        event = self.wait_cache(TreeEvent.NODE_ADDED)
        eq_(event.event_data.path, self.path + '/foo')

        with self.spy_client('get_async') as get_data:
            with self.spy_client('get_children_async') as get_children:
                # session suspended
                self.lose_connection(self.client.handler.event_object)
                self.wait_cache(TreeEvent.CONNECTION_SUSPENDED)

                # There is a series of refresh operations here, but NODE_ADDED
                # events will not be raised because the zxid of the nodes is
                # unchanged across the reconnect.

                # connection restored
                self.wait_cache(TreeEvent.CONNECTION_RECONNECTED)

                # wait for outstanding operations
                while self.cache._outstanding_ops > 0:
                    self.client.handler.sleep_func(0.1)

                # inspect in-memory nodes
                _node_root = self.cache._root
                _node_foo = self.cache._root._children['foo']

                # make sure that all nodes are refreshed
                get_data.assert_has_calls([
                    call(self.path, watch=_node_root._process_watch),
                    call(self.path + '/foo', watch=_node_foo._process_watch),
                ], any_order=True)
                get_children.assert_has_calls([
                    call(self.path, watch=_node_root._process_watch),
                    call(self.path + '/foo', watch=_node_foo._process_watch),
                ], any_order=True)

    def test_root_recreated(self):
        self.make_cache()
        self.wait_cache(since=TreeEvent.INITIALIZED)

        # remove root node
        self.client.delete(self.path)
        event = self.wait_cache(TreeEvent.NODE_REMOVED)
        eq_(event.event_type, TreeEvent.NODE_REMOVED)
        eq_(event.event_data.data, b'')
        eq_(event.event_data.path, self.path)
        eq_(event.event_data.stat.version, 0)

        # re-create root node
        self.client.ensure_path(self.path)
        event = self.wait_cache(TreeEvent.NODE_ADDED)
        eq_(event.event_type, TreeEvent.NODE_ADDED)
        eq_(event.event_data.data, b'')
        eq_(event.event_data.path, self.path)
        eq_(event.event_data.stat.version, 0)

        self.assertTrue(
            self.cache._outstanding_ops >= 0,
            'unexpected outstanding ops %r' % self.cache._outstanding_ops)

    def test_exception_handler(self):
        error_value = FakeException()
        error_handler = Mock()

        with patch.object(TreeNode, 'on_deleted') as on_deleted:
            on_deleted.side_effect = [error_value]

            self.make_cache()
            self.cache.listen_fault(error_handler)

            self.cache.close()
            error_handler.assert_called_once_with(error_value)

    def test_exception_suppressed(self):
        self.make_cache()
        self.wait_cache(since=TreeEvent.INITIALIZED)

        # provoke a ConnectionClosedError
        self.client.stop()
        self.client.close()
        self.client.handler.start()  # keep the async completion
        self.wait_cache(since=TreeEvent.CONNECTION_LOST)

        with patch.object(TreeNode, 'on_created') as on_created:
            self.cache._root._call_client('exists', '/')
            self.cache._root._call_client('get', '/')
            self.cache._root._call_client('get_children', '/')

            self.wait_cache(since=TreeEvent.INITIALIZED)
            on_created.assert_not_called()
            eq_(self.cache._outstanding_ops, 0)
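# --- Illustrative sketch (not part of the test suite) -----------------------
# The tests above exercise kazoo's TreeCache recipe end to end: prime the
# cache, listen for events, and serve reads from memory. The function below is
# a minimal usage sketch of the same public API; the ZooKeeper host string and
# the '/app/config' root path are placeholder assumptions, not values used by
# the tests.
def _tree_cache_usage_sketch():
    from kazoo.client import KazooClient
    from kazoo.recipe.cache import TreeCache, TreeEvent

    client = KazooClient(hosts='127.0.0.1:2181')  # assumed ensemble address
    client.start()

    cache = TreeCache(client, '/app/config')      # assumed root znode

    def on_event(event):
        # INITIALIZED fires once the initial snapshot of the subtree is built.
        if event.event_type == TreeEvent.INITIALIZED:
            print('cache primed')
        elif event.event_type == TreeEvent.NODE_ADDED:
            print('added:', event.event_data.path)

    cache.listen(on_event)                                    # tree events
    cache.listen_fault(lambda error: print('fault:', error))  # listener errors
    cache.start()

    # Reads are answered from the in-memory tree; no ZooKeeper round trip.
    node = cache.get_data('/app/config')          # None if the node is unknown
    children = cache.get_children('/app/config')  # frozenset of names, or None
    print(node, children)

    cache.close()
    client.stop()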
class ZookeeperClusterManager(ClusterManager):
    """
    A cluster manager that manages one cluster's state and configurations
    with a Zookeeper ensemble via kazoo.

    Below is the structure of the znodes:

        /needlestack
            /<CLUSTER_NAME_1>
                /live_nodes
                    /<HOSTPORT_1>
                    /<HOSTPORT_2>
                    /<HOSTPORT_3>
                    /<HOSTPORT_4>
                    ...
                /collections
                    /<COLLECTION_NAME_1>
                        /shards
                            /<SHARD_NAME_1>
                                /replicas
                                    /<HOSTPORT_2>
                                    /<HOSTPORT_4>
                            /<SHARD_NAME_2>
                                /replicas
                                    /<HOSTPORT_1>
                                    /<HOSTPORT_3>
                    /<COLLECTION_NAME_2>
                    ...
    """

    cluster_name: str
    hostport: str
    zk: KazooClient
    cache: TreeCache

    def __init__(self, cluster_name: str, hostport: str, hosts: List[str],
                 zookeeper_root: str):
        self.cluster_name = cluster_name
        self.hostport = hostport
        self.zookeeper_root = zookeeper_root
        self.zk = KazooClient(hosts=hosts)
        self.zk.add_listener(self.zk_listener)
        self.cache = TreeCache(self.zk, self.base_znode)

    @property
    def base_znode(self):
        return f"{self.zookeeper_root}/{self.cluster_name}"

    @property
    def live_nodes_znode(self):
        return f"{self.base_znode}/live_nodes"

    @property
    def this_node_znode(self):
        return f"{self.base_znode}/live_nodes/{self.hostport}"

    @property
    def collections_znode(self):
        return f"{self.base_znode}/collections"

    def collection_znode(self, collection_name: str) -> str:
        return f"{self.collections_znode}/{collection_name}"

    def shard_znode(self, collection_name: str, shard_name: str = None) -> str:
        znode = f"{self.collections_znode}/{collection_name}/shards"
        if shard_name:
            znode += "/" + shard_name
        return znode

    def replica_znode(self, collection_name: str, shard_name: str,
                      hostport: str = None) -> str:
        shard_znode = self.shard_znode(collection_name, shard_name)
        znode = f"{shard_znode}/replicas"
        if hostport:
            znode += "/" + hostport
        return znode

    def startup(self):
        self.zk.start()
        self.cache.start()
        signal.signal(signal.SIGINT, self.signal_listener)
        signal.signal(signal.SIGTERM, self.signal_listener)
        self.zk.ensure_path(self.live_nodes_znode)
        self.zk.ensure_path(self.collections_znode)

    def shutdown(self):
        self.cache.close()
        self.zk.stop()
        self.zk.close()

    def cleanup(self):
        logger.info("Removing ZNodes via cleanup")
        transaction = self.zk.transaction()
        for collection in self.list_local_collections():
            for shard in collection.shards:
                for replica in shard.replicas:
                    znode = self.replica_znode(collection.name, shard.name,
                                               replica.hostport)
                    transaction.delete(znode)
        self.commit_transaction(transaction)

    def register_merger(self):
        pass

    def register_searcher(self):
        try:
            retrier = KazooRetry(max_tries=5, delay=1, backoff=2, max_delay=20)
            retrier(self.zk.create, self.this_node_znode,
                    ephemeral=True, makepath=True)
            logger.info(f"Created ephemeral ZNode {self.this_node_znode}")
        except kazoo.retry.RetryFailedError:
            logger.error(
                f"Max retries reached for creating ephemeral ZNode "
                f"{self.this_node_znode}"
            )
        except kazoo.retry.InterruptedError:
            logger.error(
                f"Retries interrupted for creating ephemeral ZNode "
                f"{self.this_node_znode}"
            )

    def set_state(self, state, collection_name=None, shard_name=None,
                  hostport=None):
        transaction = self.zk.transaction()
        collections = [collection_name] if collection_name else None
        for collection in self._list_collections(collections,
                                                 hostport=hostport,
                                                 load_replica=True):
            logger.info(
                f"Set {collection.name}/shards ZNodes to "
                f"{collections_pb2.Replica.State.Name(state)}"
            )
            for shard in collection.shards:
                for replica in shard.replicas:
                    znode = self.replica_znode(collection.name, shard.name,
                                               replica.node.hostport)
                    replica.state = state
                    transaction.set_data(znode, replica.SerializeToString())
        return self.commit_transaction(transaction)

    def set_local_state(self, state, collection_name=None, shard_name=None):
        return self.set_state(state, collection_name, shard_name,
                              self.hostport)

    def signal_listener(self, signum, frame):
        self.shutdown()

    def zk_listener(self, state):
        if state == KazooState.LOST:
            logger.warn("Connection to Zookeeper lost")
        elif state == KazooState.SUSPENDED:
            logger.warn("Connection to Zookeeper disconnected")
        else:
            logger.info("Connection to Zookeeper established")

    def add_collections(self, collections):
        """Configure a list of collections into Zookeeper"""
        transaction = self.zk.transaction()

        for collection in collections:
            collection_copy = deepcopy(collection)
            collection_copy.ClearField("shards")
            collection_znode = self.collection_znode(collection.name)
            transaction.create(collection_znode,
                               collection_copy.SerializeToString())
            transaction.create(self.shard_znode(collection.name))

            for shard in collection.shards:
                shard_copy = deepcopy(shard)
                shard_copy.ClearField("replicas")
                shard_znode = self.shard_znode(collection.name, shard.name)
                transaction.create(shard_znode, shard_copy.SerializeToString())
                transaction.create(
                    self.replica_znode(collection.name, shard.name))

                for replica in shard.replicas:
                    replica_copy = deepcopy(replica)
                    replica_copy.state = collections_pb2.Replica.BOOTING
                    replica_znode = self.replica_znode(
                        collection.name, shard.name, replica.node.hostport)
                    transaction.create(replica_znode,
                                       replica_copy.SerializeToString())

        if self.commit_transaction(transaction):
            return collections
        else:
            return []

    def delete_collections(self, collection_names):
        transaction = self.zk.transaction()

        for collection_name in collection_names:
            shards_znode = self.shard_znode(collection_name)
            for shard_name in self.zk.get_children(shards_znode):
                replicas_znode = self.replica_znode(collection_name,
                                                    shard_name)
                for replica_name in self.zk.get_children(replicas_znode):
                    replica_znode = self.replica_znode(
                        collection_name, shard_name, replica_name)
                    transaction.delete(replica_znode)
                transaction.delete(replicas_znode)
                transaction.delete(
                    self.shard_znode(collection_name, shard_name))
            transaction.delete(shards_znode)
            transaction.delete(self.collection_znode(collection_name))

        if self.commit_transaction(transaction):
            return collection_names
        else:
            return []

    def list_nodes(self):
        live_nodes = self.zk.get_children(self.live_nodes_znode)
        nodes = [collections_pb2.Node(hostport=node) for node in live_nodes]
        return nodes

    def list_collections(self, collection_names=None, include_state=True):
        return self._list_collections(collection_names,
                                      load_replica=include_state)

    def list_local_collections(self, include_state=True):
        return self._list_collections(hostport=self.hostport,
                                      load_replica=include_state)

    def _list_collections(
        self,
        collection_names: Optional[List[str]] = None,
        hostport: Optional[str] = None,
        load_replica: Optional[bool] = True,
    ) -> List[collections_pb2.Collection]:
        collections = []
        collection_names = collection_names or self.zk.get_children(
            self.collections_znode)

        for collection_name in collection_names:
            shards = []
            shards_znode = self.shard_znode(collection_name)
            for shard_name in self.zk.get_children(shards_znode):
                replicas = []
                replicas_znode = self.replica_znode(collection_name,
                                                    shard_name)
                for replica_hostport in self.zk.get_children(replicas_znode):
                    if hostport == replica_hostport or hostport is None:
                        replica_znode = self.replica_znode(
                            collection_name, shard_name, replica_hostport)
                        if load_replica:
                            replica_data, _ = self.zk.get(replica_znode)
                            replica_proto = collections_pb2.Replica.FromString(
                                replica_data)
                        else:
                            replica_proto = collections_pb2.Replica()
                        replicas.append(replica_proto)

                if replicas:
                    shard_znode = self.shard_znode(collection_name, shard_name)
                    shard_data, _ = self.zk.get(shard_znode)
                    shard_proto = collections_pb2.Shard.FromString(shard_data)
                    shard_proto.replicas.extend(replicas)
                    shards.append(shard_proto)

            if shards:
                collection_znode = self.collection_znode(collection_name)
                collection_data, _ = self.zk.get(collection_znode)
                collection_proto = collections_pb2.Collection.FromString(
                    collection_data)
                collection_proto.shards.extend(shards)
                collections.append(collection_proto)

        return collections

    def get_searchers(self, collection_name, shard_names=None):
        if not shard_names:
            shards_znode = self.shard_znode(collection_name)
            shard_names = self.cache.get_children(shards_znode, [])

        shard_hostports = []
        for shard_name in shard_names:
            hostports = self._get_searchers_for_shard(collection_name,
                                                      shard_name, active=True)
            if hostports:
                shard_hostports.append((shard_name, hostports))
            else:
                logger.error(
                    f"No active Searcher node for "
                    f"{collection_name}/{shard_name}."
                )

        return shard_hostports

    def _get_searchers_for_shard(self, collection_name: str, shard_name: str,
                                 active: bool = True) -> List[str]:
        replicas_znode = self.replica_znode(collection_name, shard_name)
        hostports = self.cache.get_children(replicas_znode, [])

        if active:
            active_hostports = []
            for hostport in hostports:
                replica_znode = self.replica_znode(collection_name,
                                                   shard_name, hostport)
                node = self.cache.get_data(replica_znode)
                if node:
                    replica = collections_pb2.Replica.FromString(node.data)
                    if replica.state == collections_pb2.Replica.ACTIVE:
                        active_hostports.append(hostport)
            hostports = active_hostports

        return hostports

    def commit_transaction(
            self, transaction: kazoo.client.TransactionRequest) -> bool:
        """Commit a transaction and log the first exception after rollbacks"""
        for result, operation in zip(transaction.commit(),
                                     transaction.operations):
            if isinstance(result, kazoo.exceptions.RolledBackError):
                continue
            elif isinstance(result, Exception):
                logger.error(
                    f"{result.__class__.__name__} in Kazoo transaction: "
                    f"{operation}"
                )
                return False
        return True
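# --- Illustrative sketch (not part of the class above) ----------------------
# A minimal, hedged sketch of how ZookeeperClusterManager might be driven.
# The ensemble addresses, cluster name, host:port values, and the exact
# collections_pb2 constructor keywords are assumptions for illustration; only
# the manager methods called here are taken from the class above.
def _cluster_manager_usage_sketch():
    manager = ZookeeperClusterManager(
        cluster_name="search",            # assumed cluster name
        hostport="10.0.0.5:9090",         # this node's advertised host:port
        hosts=["zk1:2181", "zk2:2181"],   # assumed ZooKeeper ensemble
        zookeeper_root="/needlestack",
    )
    manager.startup()            # connect, start TreeCache, ensure base znodes
    manager.register_searcher()  # create this node's ephemeral live_nodes entry

    # Assumed proto construction, based on the fields the class reads
    # (Collection.name/shards, Shard.name/replicas, Replica.node.hostport).
    replica = collections_pb2.Replica(
        node=collections_pb2.Node(hostport="10.0.0.5:9090"))
    shard = collections_pb2.Shard(name="shard-0", replicas=[replica])
    collection = collections_pb2.Collection(name="products", shards=[shard])
    manager.add_collections([collection])

    # Mark this node's replicas ACTIVE, then route queries via the TreeCache.
    manager.set_local_state(collections_pb2.Replica.ACTIVE)
    for shard_name, hostports in manager.get_searchers("products"):
        print(shard_name, hostports)

    manager.shutdown()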