Exemplo n.º 1
0
class KazooTreeCacheTests(KazooAdaptiveHandlerTestCase):
    def setUp(self):
        """Run the base set-up, then create empty event/error queues.

        Both queues are built with the client handler's ``queue_impl``; the
        cache and its path are left unset until ``make_cache`` is called.
        """
        super(KazooTreeCacheTests, self).setUp()
        new_queue = self.client.handler.queue_impl
        self._event_queue, self._error_queue = new_queue(), new_queue()
        self.cache = None
        self.path = None

    def tearDown(self):
        """Surface unexpected cache faults, then always release resources.

        The first queued fault (if any) is re-raised unless it is a
        ``FakeException``. The cache is closed and the base ``tearDown`` is
        run even when that re-raise fails the test, so cleanup is never
        skipped.
        """
        try:
            if not self._error_queue.empty():
                try:
                    raise self._error_queue.get()
                except FakeException:
                    pass
        finally:
            # Cleanup must run even when a fault was re-raised above.
            try:
                if self.cache is not None:
                    self.cache.close()
                    self.cache = None
            finally:
                super(KazooTreeCacheTests, self).tearDown()

    def make_cache(self):
        """Return the test's TreeCache, creating and starting it on first use.

        A new cache is rooted at a random ``/<uuid hex>`` path; its events are
        forwarded to ``_event_queue`` and its faults to ``_error_queue``.
        """
        if self.cache is not None:
            return self.cache

        def forward_event(event):
            return self._event_queue.put(event)

        def forward_error(error):
            return self._error_queue.put(error)

        self.path = '/' + uuid.uuid4().hex
        self.cache = TreeCache(self.client, self.path)
        self.cache.listen(forward_event)
        self.cache.listen_fault(forward_error)
        self.cache.start()
        return self.cache

    def wait_cache(self, expect=None, since=None, timeout=10):
        started = since is None
        while True:
            event = self._event_queue.get(timeout=timeout)
            if started:
                if expect is not None:
                    eq_(event.event_type, expect)
                return event
            if event.event_type == since:
                started = True
                if expect is None:
                    return

    def spy_client(self, method_name):
        """Return a ``patch.object`` patcher that wraps a client method.

        The patcher replaces ``self.client.<method_name>`` with a mock that
        wraps the current attribute value.
        """
        wrapped = getattr(self.client, method_name)
        patcher = patch.object(self.client, method_name, wraps=wrapped)
        return patcher

    def _wait_gc(self):
        """Let pending handler work drain, then collect every GC generation."""
        handler = self.client.handler

        # trigger switching on some coroutine handlers
        handler.sleep_func(0.1)

        # Not every handler exposes a completion queue. Look it up on the
        # same handler object that is polled, so the existence check and the
        # polling can never disagree.
        completion_queue = getattr(handler, 'completion_queue', None)
        if completion_queue is not None:
            while not completion_queue.empty():
                handler.sleep_func(0.1)

        for gen in range(3):
            gc.collect(gen)

    def count_tree_node(self):
        """Return a stable count of live ``TreeNode`` references.

        Up to ten rounds are tried; each round takes five samples (running
        ``_wait_gc`` before each) and succeeds when all samples agree.

        :raises RuntimeError: if no round yields a single agreed value.
        """
        def sample():
            # inspect GC and count tree nodes for checking memory leak
            self._wait_gc()
            return count_refs_by_type('TreeNode')

        for _attempt in range(10):
            observed = {sample() for _probe in range(5)}
            if len(observed) == 1:
                (stable_count,) = observed
                return stable_count
        raise RuntimeError('could not count refs exactly')

    def test_start(self):
        """After initialization the root znode exists and the cache is live."""
        cache = self.make_cache()
        self.wait_cache(since=TreeEvent.INITIALIZED)

        root_stat = self.client.exists(self.path)
        eq_(root_stat.version, 0)

        eq_(cache._state, TreeCache.STATE_STARTED)
        eq_(cache._root._state, TreeNode.STATE_LIVE)

    @raises(KazooException)
    def test_start_started(self):
        """Starting a cache that is already started raises KazooException."""
        running_cache = self.make_cache()  # make_cache() starts the cache
        running_cache.start()

    @raises(KazooException)
    def test_start_closed(self):
        """Starting a cache that has been closed raises KazooException."""
        # make_cache() already starts the cache. Calling start() again before
        # close() would raise right away (see test_start_started), so the
        # closed-state start() below would never be exercised.
        self.make_cache()
        self.cache.close()
        self.cache.start()

    def test_close(self):
        """Closing the cache empties the tree, drops only the cache's own
        watchers and leaves no TreeNode objects behind."""
        eq_(self.count_tree_node(), 0)

        self.make_cache()
        self.wait_cache(since=TreeEvent.INITIALIZED)
        self.client.create(self.path + '/foo/bar/baz', makepath=True)
        for _ in range(3):
            self.wait_cache(TreeEvent.NODE_ADDED)

        # setup stub watchers which are outside of tree cache
        stub_data_watcher = Mock(spec=lambda event: None)
        stub_child_watcher = Mock(spec=lambda event: None)
        self.client.get(self.path + '/foo', stub_data_watcher)
        self.client.get_children(self.path + '/foo', stub_child_watcher)

        # watchers inside tree cache should be here
        # ('/foo' carries one extra watcher of each kind: the stub above)
        root_path = self.client.chroot + self.path
        eq_(len(self.client._data_watchers[root_path + '/foo']), 2)
        eq_(len(self.client._data_watchers[root_path + '/foo/bar']), 1)
        eq_(len(self.client._data_watchers[root_path + '/foo/bar/baz']), 1)
        eq_(len(self.client._child_watchers[root_path + '/foo']), 2)
        eq_(len(self.client._child_watchers[root_path + '/foo/bar']), 1)
        eq_(len(self.client._child_watchers[root_path + '/foo/bar/baz']), 1)

        self.cache.close()

        # nothing should be published since tree closed
        ok_(self._event_queue.empty())

        # tree should be empty
        eq_(self.cache._root._children, {})
        eq_(self.cache._root._data, None)
        eq_(self.cache._state, TreeCache.STATE_CLOSED)

        # node state should not be changed
        assert_not_equal(self.cache._root._state, TreeNode.STATE_DEAD)

        # watchers should be reset
        eq_(len(self.client._data_watchers[root_path + '/foo']), 1)
        eq_(len(self.client._data_watchers[root_path + '/foo/bar']), 0)
        eq_(len(self.client._data_watchers[root_path + '/foo/bar/baz']), 0)
        eq_(len(self.client._child_watchers[root_path + '/foo']), 1)
        eq_(len(self.client._child_watchers[root_path + '/foo/bar']), 0)
        eq_(len(self.client._child_watchers[root_path + '/foo/bar/baz']), 0)

        # outside watchers should not be deleted
        eq_(
            list(self.client._data_watchers[root_path + '/foo'])[0],
            stub_data_watcher)
        eq_(
            list(self.client._child_watchers[root_path + '/foo'])[0],
            stub_child_watcher)

        # should not be any leaked memory (tree node) here
        # (drop the test's own reference first so it is not counted)
        self.cache = None
        eq_(self.count_tree_node(), 0)

    def test_delete_operation(self):
        """Deleting a subtree empties the cached tree, clears its watchers
        and leaves only the root TreeNode alive."""
        self.make_cache()
        self.wait_cache(since=TreeEvent.INITIALIZED)

        # only the root node is expected at this point
        eq_(self.count_tree_node(), 1)

        self.client.create(self.path + '/foo/bar/baz', makepath=True)
        for _ in range(3):
            self.wait_cache(TreeEvent.NODE_ADDED)

        self.client.delete(self.path + '/foo', recursive=True)
        for _ in range(3):
            self.wait_cache(TreeEvent.NODE_REMOVED)

        # tree should be empty
        eq_(self.cache._root._children, {})

        # watchers should be reset
        root_path = self.client.chroot + self.path
        eq_(self.client._data_watchers[root_path + '/foo'], set())
        eq_(self.client._data_watchers[root_path + '/foo/bar'], set())
        eq_(self.client._data_watchers[root_path + '/foo/bar/baz'], set())
        eq_(self.client._child_watchers[root_path + '/foo'], set())
        eq_(self.client._child_watchers[root_path + '/foo/bar'], set())
        eq_(self.client._child_watchers[root_path + '/foo/bar/baz'], set())

        # should not be any leaked memory (tree node) here
        eq_(self.count_tree_node(), 1)

    def test_children_operation(self):
        """Create, update and delete of a direct child each publish one
        event carrying the child's path, data and stat version."""
        self.make_cache()
        self.wait_cache(since=TreeEvent.INITIALIZED)

        # create -> NODE_ADDED with the initial data at version 0
        self.client.create(self.path + '/test_children', b'test_children_1')
        event = self.wait_cache(TreeEvent.NODE_ADDED)
        eq_(event.event_type, TreeEvent.NODE_ADDED)
        eq_(event.event_data.path, self.path + '/test_children')
        eq_(event.event_data.data, b'test_children_1')
        eq_(event.event_data.stat.version, 0)

        # set -> NODE_UPDATED with the new data at version 1
        self.client.set(self.path + '/test_children', b'test_children_2')
        event = self.wait_cache(TreeEvent.NODE_UPDATED)
        eq_(event.event_type, TreeEvent.NODE_UPDATED)
        eq_(event.event_data.path, self.path + '/test_children')
        eq_(event.event_data.data, b'test_children_2')
        eq_(event.event_data.stat.version, 1)

        # delete -> NODE_REMOVED still reporting the last known data/version
        self.client.delete(self.path + '/test_children')
        event = self.wait_cache(TreeEvent.NODE_REMOVED)
        eq_(event.event_type, TreeEvent.NODE_REMOVED)
        eq_(event.event_data.path, self.path + '/test_children')
        eq_(event.event_data.data, b'test_children_2')
        eq_(event.event_data.stat.version, 1)

    def test_subtree_operation(self):
        """A nested create publishes NODE_ADDED top-down; a recursive delete
        publishes NODE_REMOVED bottom-up."""
        self.make_cache()
        self.wait_cache(since=TreeEvent.INITIALIZED)

        top_down = ('/foo', '/foo/bar', '/foo/bar/baz')
        bottom_up = ('/foo/bar/baz', '/foo/bar', '/foo')

        self.client.create(self.path + '/foo/bar/baz', makepath=True)
        for suffix in top_down:
            added = self.wait_cache(TreeEvent.NODE_ADDED)
            eq_(added.event_type, TreeEvent.NODE_ADDED)
            eq_(added.event_data.path, self.path + suffix)
            eq_(added.event_data.data, b'')
            eq_(added.event_data.stat.version, 0)

        self.client.delete(self.path + '/foo', recursive=True)
        for suffix in bottom_up:
            removed = self.wait_cache(TreeEvent.NODE_REMOVED)
            eq_(removed.event_type, TreeEvent.NODE_REMOVED)
            eq_(removed.event_data.path, self.path + suffix)

    def test_get_data(self):
        """``get_data`` answers from the cache for every node of the tree."""
        cache = self.make_cache()
        self.wait_cache(since=TreeEvent.INITIALIZED)
        self.client.create(self.path + '/foo/bar/baz', b'@', makepath=True)
        for _ in range(3):
            self.wait_cache(TreeEvent.NODE_ADDED)

        # (path suffix under the cache root, expected node data)
        expectations = (
            ('', b''),
            ('/foo', b''),
            ('/foo/bar', b''),
            ('/foo/bar/baz', b'@'),
        )
        with patch.object(cache, '_client'):  # disable any remote operation
            for suffix, payload in expectations:
                node_path = self.path + suffix
                eq_(cache.get_data(node_path).data, payload)
                eq_(cache.get_data(node_path).stat.version, 0)

    def test_get_children(self):
        """``get_children`` answers from the cache for every tree level."""
        cache = self.make_cache()
        self.wait_cache(since=TreeEvent.INITIALIZED)
        self.client.create(self.path + '/foo/bar/baz', b'@', makepath=True)
        for _ in range(3):
            self.wait_cache(TreeEvent.NODE_ADDED)

        # (path suffix under the cache root, expected child names)
        expectations = (
            ('/foo/bar/baz', frozenset()),
            ('/foo/bar', frozenset(['baz'])),
            ('/foo', frozenset(['bar'])),
            ('', frozenset(['foo'])),
        )
        with patch.object(cache, '_client'):  # disable any remote operation
            for suffix, children in expectations:
                eq_(cache.get_children(self.path + suffix), children)

    @raises(ValueError)
    def test_get_data_out_of_tree(self):
        """``get_data`` on a path outside the cache root raises ValueError."""
        cache = self.make_cache()
        self.wait_cache(since=TreeEvent.INITIALIZED)
        outside_path = '/out_of_tree'
        cache.get_data(outside_path)

    @raises(ValueError)
    def test_get_children_out_of_tree(self):
        """``get_children`` on a path outside the cache root raises
        ValueError."""
        cache = self.make_cache()
        self.wait_cache(since=TreeEvent.INITIALIZED)
        outside_path = '/out_of_tree'
        cache.get_children(outside_path)

    def test_get_data_no_node(self):
        """``get_data`` returns None for an in-tree path with no node."""
        cache = self.make_cache()
        self.wait_cache(since=TreeEvent.INITIALIZED)
        missing_path = self.path + '/non_exists'

        with patch.object(cache, '_client'):  # disable any remote operation
            found = cache.get_data(missing_path)
            eq_(found, None)

    def test_get_children_no_node(self):
        """``get_children`` returns None for an in-tree path with no node."""
        cache = self.make_cache()
        self.wait_cache(since=TreeEvent.INITIALIZED)
        missing_path = self.path + '/non_exists'

        with patch.object(cache, '_client'):  # disable any remote operation
            found = cache.get_children(missing_path)
            eq_(found, None)

    def test_session_reconnected(self):
        """After a lost connection is restored, every cached node re-issues
        its ``get_async`` and ``get_children_async`` calls."""
        self.make_cache()
        self.wait_cache(since=TreeEvent.INITIALIZED)

        self.client.create(self.path + '/foo')
        event = self.wait_cache(TreeEvent.NODE_ADDED)
        eq_(event.event_data.path, self.path + '/foo')

        with self.spy_client('get_async') as get_data:
            with self.spy_client('get_children_async') as get_children:
                # session suspended
                self.lose_connection(self.client.handler.event_object)
                self.wait_cache(TreeEvent.CONNECTION_SUSPENDED)

                # A series of refresh operations happens here, but no
                # NODE_ADDED events are raised because the zxid of the nodes
                # stays the same across the reconnect.

                # connection restore
                self.wait_cache(TreeEvent.CONNECTION_RECONNECTED)

                # wait for outstanding operations
                while self.cache._outstanding_ops > 0:
                    self.client.handler.sleep_func(0.1)

                # inspect in-memory nodes
                _node_root = self.cache._root
                _node_foo = self.cache._root._children['foo']

                # make sure that all nodes are refreshed
                get_data.assert_has_calls([
                    call(self.path, watch=_node_root._process_watch),
                    call(self.path + '/foo', watch=_node_foo._process_watch),
                ],
                                          any_order=True)
                get_children.assert_has_calls([
                    call(self.path, watch=_node_root._process_watch),
                    call(self.path + '/foo', watch=_node_foo._process_watch),
                ],
                                              any_order=True)

    def test_root_recreated(self):
        """Deleting and re-creating the root znode publishes NODE_REMOVED
        then NODE_ADDED for the root path."""
        self.make_cache()
        self.wait_cache(since=TreeEvent.INITIALIZED)

        # remove root node
        self.client.delete(self.path)
        event = self.wait_cache(TreeEvent.NODE_REMOVED)
        eq_(event.event_type, TreeEvent.NODE_REMOVED)
        eq_(event.event_data.data, b'')
        eq_(event.event_data.path, self.path)
        eq_(event.event_data.stat.version, 0)

        # re-create root node
        self.client.ensure_path(self.path)
        event = self.wait_cache(TreeEvent.NODE_ADDED)
        eq_(event.event_type, TreeEvent.NODE_ADDED)
        eq_(event.event_data.data, b'')
        eq_(event.event_data.path, self.path)
        eq_(event.event_data.stat.version, 0)

        # the outstanding-operation counter must never go negative
        self.assertTrue(
            self.cache._outstanding_ops >= 0,
            'unexpected outstanding ops %r' % self.cache._outstanding_ops)

    def test_exception_handler(self):
        """An exception raised by ``TreeNode.on_deleted`` during close is
        passed to listeners registered with ``listen_fault``."""
        injected_error = FakeException()
        fault_listener = Mock()

        with patch.object(TreeNode, 'on_deleted') as on_deleted:
            on_deleted.side_effect = [injected_error]

            cache = self.make_cache()
            cache.listen_fault(fault_listener)

            cache.close()
            fault_listener.assert_called_once_with(injected_error)

    def test_exception_suppressed(self):
        """Client calls issued after the client is stopped and closed must
        not reach ``on_created`` and must leave no outstanding ops."""
        self.make_cache()
        self.wait_cache(since=TreeEvent.INITIALIZED)

        # provoke ConnectionClosedError
        self.client.stop()
        self.client.close()
        self.client.handler.start()  # keep the async completion
        self.wait_cache(since=TreeEvent.CONNECTION_LOST)

        with patch.object(TreeNode, 'on_created') as on_created:
            self.cache._root._call_client('exists', '/')
            self.cache._root._call_client('get', '/')
            self.cache._root._call_client('get_children', '/')

            self.wait_cache(since=TreeEvent.INITIALIZED)
            on_created.assert_not_called()
            eq_(self.cache._outstanding_ops, 0)
Exemplo n.º 2
0
class KazooTreeCacheTests(KazooTestCase):
    def setUp(self):
        """Run the base set-up, then create empty event/error queues.

        Both queues are built with the client handler's ``queue_impl``; the
        cache and its path are left unset until ``make_cache`` is called.
        """
        super(KazooTreeCacheTests, self).setUp()
        queue_factory = self.client.handler.queue_impl
        self._event_queue = queue_factory()
        self._error_queue = queue_factory()
        self.path = self.cache = None

    def tearDown(self):
        """Run the base tear-down, then surface an unexpected cache fault.

        The first queued fault (if any) is re-raised unless it is a
        ``FakeException``.
        """
        super(KazooTreeCacheTests, self).tearDown()
        if self._error_queue.empty():
            return
        try:
            raise self._error_queue.get()
        except FakeException:
            # faults injected on purpose by the tests are not failures
            pass

    def make_cache(self):
        """Return the test's TreeCache, creating and starting it on first use.

        A new cache is rooted at a random ``/<uuid hex>`` path; its events are
        forwarded to ``_event_queue`` and its faults to ``_error_queue``.
        """
        if self.cache is not None:
            return self.cache

        def forward_event(event):
            return self._event_queue.put(event)

        def forward_error(error):
            return self._error_queue.put(error)

        self.path = '/' + uuid.uuid4().hex
        self.cache = TreeCache(self.client, self.path)
        self.cache.listen(forward_event)
        self.cache.listen_fault(forward_error)
        self.cache.start()
        return self.cache

    def wait_cache(self, expect=None, since=None, timeout=10):
        started = since is None
        while True:
            event = self._event_queue.get(timeout=timeout)
            if started:
                if expect is not None:
                    eq_(event.event_type, expect)
                return event
            if event.event_type == since:
                started = True
                if expect is None:
                    return

    def spy_client(self, method_name):
        """Return a ``patch.object`` patcher that wraps a client method.

        The patcher replaces ``self.client.<method_name>`` with a mock that
        wraps the current attribute value.
        """
        wrapped = getattr(self.client, method_name)
        patcher = patch.object(self.client, method_name, wraps=wrapped)
        return patcher

    def test_start(self):
        """After initialization the root znode exists and the cache is live."""
        cache = self.make_cache()
        self.wait_cache(since=TreeEvent.INITIALIZED)

        root_stat = self.client.exists(self.path)
        eq_(root_stat.version, 0)

        eq_(cache._state, TreeCache.STATE_STARTED)
        eq_(cache._root._state, TreeNode.STATE_LIVE)

    @raises(KazooException)
    def test_start_started(self):
        """Starting a cache that is already started raises KazooException."""
        running_cache = self.make_cache()  # make_cache() starts the cache
        running_cache.start()

    @raises(KazooException)
    def test_start_closed(self):
        """Starting a cache that has been closed raises KazooException."""
        # make_cache() already starts the cache. Calling start() again before
        # close() would raise right away (see test_start_started), so the
        # closed-state start() below would never be exercised.
        self.make_cache()
        self.cache.close()
        self.cache.start()

    def test_close(self):
        """Closing the cache stops event publication and empties the tree
        without marking the root node dead."""
        self.make_cache()
        self.wait_cache(since=TreeEvent.INITIALIZED)
        self.client.create(self.path + '/foo/bar/baz', makepath=True)
        for _ in range(3):
            self.wait_cache(TreeEvent.NODE_ADDED)

        self.cache.close()

        # nothing should be published since tree closed
        ok_(self._event_queue.empty())

        # tree should be empty
        eq_(self.cache._root._children, {})
        eq_(self.cache._root._data, None)
        eq_(self.cache._state, TreeCache.STATE_CLOSED)

        # node state should not be changed
        assert_not_equal(self.cache._root._state, TreeNode.STATE_DEAD)

    def test_children_operation(self):
        """Create, update and delete of a direct child each publish one
        event carrying the child's path, data and stat version."""
        self.make_cache()
        self.wait_cache(since=TreeEvent.INITIALIZED)

        # create -> NODE_ADDED with the initial data at version 0
        self.client.create(self.path + '/test_children', b'test_children_1')
        event = self.wait_cache(TreeEvent.NODE_ADDED)
        eq_(event.event_type, TreeEvent.NODE_ADDED)
        eq_(event.event_data.path, self.path + '/test_children')
        eq_(event.event_data.data, b'test_children_1')
        eq_(event.event_data.stat.version, 0)

        # set -> NODE_UPDATED with the new data at version 1
        self.client.set(self.path + '/test_children', b'test_children_2')
        event = self.wait_cache(TreeEvent.NODE_UPDATED)
        eq_(event.event_type, TreeEvent.NODE_UPDATED)
        eq_(event.event_data.path, self.path + '/test_children')
        eq_(event.event_data.data, b'test_children_2')
        eq_(event.event_data.stat.version, 1)

        # delete -> NODE_REMOVED still reporting the last known data/version
        self.client.delete(self.path + '/test_children')
        event = self.wait_cache(TreeEvent.NODE_REMOVED)
        eq_(event.event_type, TreeEvent.NODE_REMOVED)
        eq_(event.event_data.path, self.path + '/test_children')
        eq_(event.event_data.data, b'test_children_2')
        eq_(event.event_data.stat.version, 1)

    def test_subtree_operation(self):
        """A nested create publishes NODE_ADDED top-down; a recursive delete
        publishes NODE_REMOVED bottom-up."""
        self.make_cache()
        self.wait_cache(since=TreeEvent.INITIALIZED)

        top_down = ('/foo', '/foo/bar', '/foo/bar/baz')
        bottom_up = ('/foo/bar/baz', '/foo/bar', '/foo')

        self.client.create(self.path + '/foo/bar/baz', makepath=True)
        for suffix in top_down:
            added = self.wait_cache(TreeEvent.NODE_ADDED)
            eq_(added.event_type, TreeEvent.NODE_ADDED)
            eq_(added.event_data.path, self.path + suffix)
            eq_(added.event_data.data, b'')
            eq_(added.event_data.stat.version, 0)

        self.client.delete(self.path + '/foo', recursive=True)
        for suffix in bottom_up:
            removed = self.wait_cache(TreeEvent.NODE_REMOVED)
            eq_(removed.event_type, TreeEvent.NODE_REMOVED)
            eq_(removed.event_data.path, self.path + suffix)

    def test_get_data(self):
        """``get_data`` answers from the cache for every node of the tree."""
        cache = self.make_cache()
        self.wait_cache(since=TreeEvent.INITIALIZED)
        self.client.create(self.path + '/foo/bar/baz', b'@', makepath=True)
        for _ in range(3):
            self.wait_cache(TreeEvent.NODE_ADDED)

        # (path suffix under the cache root, expected node data)
        expectations = (
            ('', b''),
            ('/foo', b''),
            ('/foo/bar', b''),
            ('/foo/bar/baz', b'@'),
        )
        with patch.object(cache, '_client'):  # disable any remote operation
            for suffix, payload in expectations:
                node_path = self.path + suffix
                eq_(cache.get_data(node_path).data, payload)
                eq_(cache.get_data(node_path).stat.version, 0)

    def test_get_children(self):
        """``get_children`` answers from the cache for every tree level."""
        cache = self.make_cache()
        self.wait_cache(since=TreeEvent.INITIALIZED)
        self.client.create(self.path + '/foo/bar/baz', b'@', makepath=True)
        for _ in range(3):
            self.wait_cache(TreeEvent.NODE_ADDED)

        # (path suffix under the cache root, expected child names)
        expectations = (
            ('/foo/bar/baz', frozenset()),
            ('/foo/bar', frozenset(['baz'])),
            ('/foo', frozenset(['bar'])),
            ('', frozenset(['foo'])),
        )
        with patch.object(cache, '_client'):  # disable any remote operation
            for suffix, children in expectations:
                eq_(cache.get_children(self.path + suffix), children)

    @raises(ValueError)
    def test_get_data_out_of_tree(self):
        """``get_data`` on a path outside the cache root raises ValueError."""
        cache = self.make_cache()
        self.wait_cache(since=TreeEvent.INITIALIZED)
        outside_path = '/out_of_tree'
        cache.get_data(outside_path)

    @raises(ValueError)
    def test_get_children_out_of_tree(self):
        """``get_children`` on a path outside the cache root raises
        ValueError."""
        cache = self.make_cache()
        self.wait_cache(since=TreeEvent.INITIALIZED)
        outside_path = '/out_of_tree'
        cache.get_children(outside_path)

    def test_get_data_no_node(self):
        """``get_data`` returns None for an in-tree path with no node."""
        cache = self.make_cache()
        self.wait_cache(since=TreeEvent.INITIALIZED)
        missing_path = self.path + '/non_exists'

        with patch.object(cache, '_client'):  # disable any remote operation
            found = cache.get_data(missing_path)
            eq_(found, None)

    def test_get_children_no_node(self):
        """``get_children`` returns None for an in-tree path with no node."""
        cache = self.make_cache()
        self.wait_cache(since=TreeEvent.INITIALIZED)
        missing_path = self.path + '/non_exists'

        with patch.object(cache, '_client'):  # disable any remote operation
            found = cache.get_children(missing_path)
            eq_(found, None)

    def test_session_reconnected(self):
        """After a lost connection is restored, every cached node re-issues
        its ``get_async`` and ``get_children_async`` calls."""
        self.make_cache()
        self.wait_cache(since=TreeEvent.INITIALIZED)

        self.client.create(self.path + '/foo')
        event = self.wait_cache(TreeEvent.NODE_ADDED)
        eq_(event.event_data.path, self.path + '/foo')

        with self.spy_client('get_async') as get_data:
            with self.spy_client('get_children_async') as get_children:
                # session suspended
                self.lose_connection(self.client.handler.event_object)
                self.wait_cache(TreeEvent.CONNECTION_SUSPENDED)

                # A series of refresh operations happens here, but no
                # NODE_ADDED events are raised because the zxid of the nodes
                # stays the same across the reconnect.

                # connection restore
                self.wait_cache(TreeEvent.CONNECTION_RECONNECTED)

                # wait for outstanding operations
                while self.cache._outstanding_ops > 0:
                    self.client.handler.sleep_func(0.1)

                # inspect in-memory nodes
                _node_root = self.cache._root
                _node_foo = self.cache._root._children['foo']

                # make sure that all nodes are refreshed
                get_data.assert_has_calls([
                    call(self.path, watch=_node_root._process_watch),
                    call(self.path + '/foo', watch=_node_foo._process_watch),
                ],
                                          any_order=True)
                get_children.assert_has_calls([
                    call(self.path, watch=_node_root._process_watch),
                    call(self.path + '/foo', watch=_node_foo._process_watch),
                ],
                                              any_order=True)

    def test_root_recreated(self):
        """Deleting and re-creating the root znode publishes NODE_REMOVED
        then NODE_ADDED for the root path."""
        self.make_cache()
        self.wait_cache(since=TreeEvent.INITIALIZED)

        # remove root node
        self.client.delete(self.path)
        event = self.wait_cache(TreeEvent.NODE_REMOVED)
        eq_(event.event_type, TreeEvent.NODE_REMOVED)
        eq_(event.event_data.data, b'')
        eq_(event.event_data.path, self.path)
        eq_(event.event_data.stat.version, 0)

        # re-create root node
        self.client.ensure_path(self.path)
        event = self.wait_cache(TreeEvent.NODE_ADDED)
        eq_(event.event_type, TreeEvent.NODE_ADDED)
        eq_(event.event_data.data, b'')
        eq_(event.event_data.path, self.path)
        eq_(event.event_data.stat.version, 0)

        # the outstanding-operation counter must never go negative
        self.assertTrue(
            self.cache._outstanding_ops >= 0,
            'unexpected outstanding ops %r' % self.cache._outstanding_ops)

    def test_exception_handler(self):
        """An exception raised by ``TreeNode.on_deleted`` during close is
        passed to listeners registered with ``listen_fault``."""
        injected_error = FakeException()
        fault_listener = Mock()

        with patch.object(TreeNode, 'on_deleted') as on_deleted:
            on_deleted.side_effect = [injected_error]

            cache = self.make_cache()
            cache.listen_fault(fault_listener)

            cache.close()
            fault_listener.assert_called_once_with(injected_error)
Exemplo n.º 3
0
class ZookeeperClusterManager(ClusterManager):
    """
    A cluster manager that manages one cluster's state and configurations
    with a Zookeeper ensemble via kazoo.

    Below is the structure of the znodes:
        /needlestack
            /<CLUSTER_NAME_1>
                /live_nodes
                    /<HOSTPORT_1>
                    /<HOSTPORT_2>
                    /<HOSTPORT_3>
                    /<HOSTPORT_4>
                    ...
                /collections
                    /<COLLECTION_NAME_1>
                        /shards
                            /<SHARD_NAME_1>
                                /replicas
                                    /<HOSTPORT_2>
                                    /<HOSTPORT_4>
                            /<SHARD_NAME_2>
                                /replicas
                                    /<HOSTPORT_1>
                                    /<HOSTPORT_3>
                    /<COLLECTION_NAME_2>
                        ...
    """

    cluster_name: str
    hostport: str
    zk: KazooClient
    cache: TreeCache

    def __init__(self, cluster_name: str, hostport: str, hosts: List[str],
                 zookeeper_root: str):
        """Store the cluster identity and build the client and tree cache.

        Nothing is started here. The ``KazooClient`` is created for ``hosts``
        with ``zk_listener`` registered, and the ``TreeCache`` is rooted at
        ``base_znode``.
        """
        self.cluster_name = cluster_name
        self.hostport = hostport
        self.zookeeper_root = zookeeper_root

        client = KazooClient(hosts=hosts)
        self.zk = client
        client.add_listener(self.zk_listener)
        self.cache = TreeCache(client, self.base_znode)

    @property
    def base_znode(self):
        return f"{self.zookeeper_root}/{self.cluster_name}"

    @property
    def live_nodes_znode(self):
        return f"{self.base_znode}/live_nodes"

    @property
    def this_node_znode(self):
        return f"{self.base_znode}/live_nodes/{self.hostport}"

    @property
    def collections_znode(self):
        return f"{self.base_znode}/collections"

    def collection_znode(self, collection_name: str) -> str:
        return f"{self.collections_znode}/{collection_name}"

    def shard_znode(self, collection_name: str, shard_name: str = None) -> str:
        znode = f"{self.collections_znode}/{collection_name}/shards"
        if shard_name:
            znode += "/" + shard_name
        return znode

    def replica_znode(self,
                      collection_name: str,
                      shard_name: str,
                      hostport: str = None) -> str:
        shard_znode = self.shard_znode(collection_name, shard_name)
        znode = f"{shard_znode}/replicas"
        if hostport:
            znode += "/" + hostport
        return znode

    def startup(self):
        """Start the client and cache, hook signals, ensure base znodes.

        ``signal_listener`` is installed for SIGINT and SIGTERM, and the
        ``live_nodes`` and ``collections`` znodes are created if missing.
        """
        self.zk.start()
        self.cache.start()

        for signum in (signal.SIGINT, signal.SIGTERM):
            signal.signal(signum, self.signal_listener)

        for znode in (self.live_nodes_znode, self.collections_znode):
            self.zk.ensure_path(znode)

    def shutdown(self):
        """Close the tree cache, then stop and close the Zookeeper client."""
        # The cache is closed first, while the client it uses is still up.
        self.cache.close()
        self.zk.stop()
        self.zk.close()

    def cleanup(self):
        """Delete the replica znodes of every local collection.

        All deletes are queued on one transaction, which is then committed
        through ``commit_transaction``.
        """
        logger.info("Removing ZNodes via cleanup")
        transaction = self.zk.transaction()

        # NOTE(review): this reads ``replica.hostport`` while set_state and
        # add_collections read ``replica.node.hostport`` — confirm which
        # shape ``list_local_collections`` returns.
        replica_znodes = (
            self.replica_znode(collection.name, shard.name, replica.hostport)
            for collection in self.list_local_collections()
            for shard in collection.shards
            for replica in shard.replicas
        )
        for znode in replica_znodes:
            transaction.delete(znode)

        self.commit_transaction(transaction)

    def register_merger(self):
        pass

    def register_searcher(self):
        """Create this node's ephemeral znode under ``live_nodes``.

        The create call goes through a ``KazooRetry``. Retry exhaustion and
        retry interruption are logged as errors and not re-raised.
        """
        znode = self.this_node_znode
        retrier = KazooRetry(max_tries=5, delay=1, backoff=2, max_delay=20)
        try:
            retrier(self.zk.create, znode, ephemeral=True, makepath=True)
        except kazoo.retry.RetryFailedError:
            logger.error(
                f"Max retries reached for creating ephemeral ZNode {znode}"
            )
        except kazoo.retry.InterruptedError:
            logger.error(
                f"Retries interrupted for creating ephemeral ZNode {znode}"
            )
        else:
            logger.info(f"Created ephemeral ZNode {znode}")

    def set_state(self,
                  state,
                  collection_name=None,
                  shard_name=None,
                  hostport=None):
        """Set the state of matching replicas in a single transaction.

        Args:
            state: Replica state value written into each matching replica.
            collection_name: Restrict to this collection; all when falsy.
            shard_name: Restrict to this shard; all shards when falsy.
            hostport: Forwarded to ``_list_collections`` as ``hostport``.

        Returns:
            Whatever ``commit_transaction`` returns for the transaction.
        """
        transaction = self.zk.transaction()

        collections = [collection_name] if collection_name else None
        for collection in self._list_collections(collections,
                                                 hostport=hostport,
                                                 load_replica=True):
            logger.info(
                f"Set {collection.name}/shards ZNodes to {collections_pb2.Replica.State.Name(state)}"
            )
            for shard in collection.shards:
                # Skip shards excluded by the optional shard_name filter.
                if shard_name and shard.name != shard_name:
                    continue
                for replica in shard.replicas:
                    znode = self.replica_znode(collection.name, shard.name,
                                               replica.node.hostport)
                    replica.state = state
                    transaction.set_data(znode, replica.SerializeToString())

        return self.commit_transaction(transaction)

    def set_local_state(self, state, collection_name=None, shard_name=None):
        return self.set_state(state, collection_name, shard_name,
                              self.hostport)

    def signal_listener(self, signum, frame):
        """Signal handler installed by ``startup``; shuts the manager down.

        ``signum`` and ``frame`` are accepted to match the signal-handler
        call signature and are not used.
        """
        self.shutdown()

    def zk_listener(self, state):
        """Log Zookeeper connection state changes.

        ``KazooState.LOST`` and ``KazooState.SUSPENDED`` are logged as
        warnings; any other state is logged as an established connection.
        """
        # Logger.warn is a deprecated alias; use Logger.warning.
        if state == KazooState.LOST:
            logger.warning("Connection to Zookeeper lost")
        elif state == KazooState.SUSPENDED:
            logger.warning("Connection to Zookeeper disconnected")
        else:
            logger.info("Connection to Zookeeper established")

    def add_collections(self, collections):
        """Configure a list of collections into Zookeeper

        Collection, shard and replica znodes are queued on one transaction,
        parents before children. Collections are stored without their
        ``shards`` field, shards without their ``replicas`` field, and every
        replica is stored in the ``BOOTING`` state. The inputs are not
        modified; copies are serialized.

        Returns ``collections`` when the commit succeeds, else ``[]``.
        """
        transaction = self.zk.transaction()

        for collection in collections:
            name = collection.name

            collection_meta = deepcopy(collection)
            collection_meta.ClearField("shards")
            transaction.create(self.collection_znode(name),
                               collection_meta.SerializeToString())
            transaction.create(self.shard_znode(name))

            for shard in collection.shards:
                shard_meta = deepcopy(shard)
                shard_meta.ClearField("replicas")
                transaction.create(self.shard_znode(name, shard.name),
                                   shard_meta.SerializeToString())
                transaction.create(self.replica_znode(name, shard.name))

                for replica in shard.replicas:
                    booting_replica = deepcopy(replica)
                    booting_replica.state = collections_pb2.Replica.BOOTING
                    transaction.create(
                        self.replica_znode(name, shard.name,
                                           replica.node.hostport),
                        booting_replica.SerializeToString())

        if not self.commit_transaction(transaction):
            return []
        return collections

    def delete_collections(self, collection_names):
        """Delete the znodes of the named collections in one transaction.

        Shard and replica names are read through ``self.zk.get_children`` and
        the deletes are queued children-first: each replica, the replicas
        parent, the shard, then the shards parent and the collection znode.

        Returns:
            ``collection_names`` when the transaction commits, otherwise
            ``[]``.
        """
        txn = self.zk.transaction()

        for name in collection_names:
            shards_root = self.shard_znode(name)
            for shard in self.zk.get_children(shards_root):
                replicas_root = self.replica_znode(name, shard)
                # Queue every child before the parent that contains it.
                for replica in self.zk.get_children(replicas_root):
                    txn.delete(self.replica_znode(name, shard, replica))
                txn.delete(replicas_root)
                txn.delete(self.shard_znode(name, shard))
            txn.delete(shards_root)
            txn.delete(self.collection_znode(name))

        if not self.commit_transaction(txn):
            return []
        return collection_names

    def list_nodes(self):
        """Return one ``Node`` message per child of the live-nodes znode.

        Each child name read from ``self.live_nodes_znode`` becomes the
        ``hostport`` of the corresponding message.
        """
        return [
            collections_pb2.Node(hostport=child)
            for child in self.zk.get_children(self.live_nodes_znode)
        ]

    def list_collections(self, collection_names=None, include_state=True):
        """List collections without restricting replicas to one hostport.

        Args:
            collection_names: Forwarded to ``_list_collections`` as its first
                positional argument.
            include_state: Forwarded to ``_list_collections`` as
                ``load_replica``.

        Returns:
            Whatever ``_list_collections`` returns.
        """
        load_replica = include_state
        return self._list_collections(collection_names,
                                      load_replica=load_replica)

    def list_local_collections(self, include_state=True):
        """List collections restricted to this instance's own hostport.

        Args:
            include_state: Forwarded to ``_list_collections`` as
                ``load_replica``.

        Returns:
            Whatever ``_list_collections`` returns for
            ``hostport=self.hostport``.
        """
        load_replica = include_state
        return self._list_collections(hostport=self.hostport,
                                      load_replica=load_replica)

    def _list_collections(
        self,
        collection_names: Optional[List[str]] = None,
        hostport: Optional[str] = None,
        load_replica: Optional[bool] = True,
    ) -> List[collections_pb2.Collection]:
        """Read collections, with their shards and replicas, from Zookeeper.

        Args:
            collection_names: Collections to read; when empty or ``None`` the
                children of ``self.collections_znode`` are used instead.
            hostport: When given, only replicas whose znode name equals this
                value are kept; ``None`` keeps every replica.
            load_replica: When true each kept replica znode is fetched and
                parsed; otherwise an empty ``Replica`` message stands in.

        Returns:
            One ``Collection`` message per collection that has at least one
            shard left after filtering. Shards with no kept replica are
            omitted.
        """
        result = []

        names = collection_names or self.zk.get_children(
            self.collections_znode)
        for name in names:
            shard_protos = []
            shards_root = self.shard_znode(name)

            for shard_name in self.zk.get_children(shards_root):
                replica_protos = []
                replicas_root = self.replica_znode(name, shard_name)

                for replica_hostport in self.zk.get_children(replicas_root):
                    # Skip replicas hosted elsewhere when a hostport is given.
                    if hostport is not None and hostport != replica_hostport:
                        continue

                    replica_path = self.replica_znode(name, shard_name,
                                                      replica_hostport)
                    if not load_replica:
                        replica_protos.append(collections_pb2.Replica())
                        continue

                    raw_replica, _ = self.zk.get(replica_path)
                    replica_protos.append(
                        collections_pb2.Replica.FromString(raw_replica))

                # A shard is only reported when at least one replica was kept.
                if not replica_protos:
                    continue

                raw_shard, _ = self.zk.get(self.shard_znode(name, shard_name))
                shard_proto = collections_pb2.Shard.FromString(raw_shard)
                shard_proto.replicas.extend(replica_protos)
                shard_protos.append(shard_proto)

            # Likewise, a collection needs at least one reported shard.
            if not shard_protos:
                continue

            raw_collection, _ = self.zk.get(self.collection_znode(name))
            collection_proto = collections_pb2.Collection.FromString(
                raw_collection)
            collection_proto.shards.extend(shard_protos)
            result.append(collection_proto)

        return result

    def get_searchers(self, collection_name, shard_names=None):
        """Return ``(shard_name, hostports)`` pairs for a collection's shards.

        Args:
            collection_name: Collection whose shards are looked up.
            shard_names: Shards to consider; when empty or ``None`` the shard
                names are read from ``self.cache``.

        Returns:
            A list of ``(shard_name, hostports)`` tuples. Shards for which no
            active hostport is found are logged as errors and left out.
        """
        if not shard_names:
            shard_names = self.cache.get_children(
                self.shard_znode(collection_name), [])

        found = []
        for shard_name in shard_names:
            hostports = self._get_searchers_for_shard(collection_name,
                                                      shard_name,
                                                      active=True)
            if not hostports:
                logger.error(
                    f"No active Searcher node for {collection_name}/{shard_name}."
                )
                continue
            found.append((shard_name, hostports))

        return found

    def _get_searchers_for_shard(self,
                                 collection_name: str,
                                 shard_name: str,
                                 active: bool = True) -> List[str]:
        """Return the replica hostports cached for one shard.

        Args:
            collection_name: Collection that owns the shard.
            shard_name: Shard whose replica children are read from the cache.
            active: When true, keep only hostports whose cached replica data
                parses to state ``ACTIVE``; hostports without cached data are
                dropped. When false, the cached children are returned as-is.
        """
        replicas_root = self.replica_znode(collection_name, shard_name)
        candidates = self.cache.get_children(replicas_root, [])
        if not active:
            return candidates

        active_hostports = []
        for candidate in candidates:
            replica_path = self.replica_znode(collection_name, shard_name,
                                              candidate)
            node = self.cache.get_data(replica_path)
            if not node:
                continue
            replica = collections_pb2.Replica.FromString(node.data)
            if replica.state == collections_pb2.Replica.ACTIVE:
                active_hostports.append(candidate)

        return active_hostports

    def commit_transaction(
            self, transaction: kazoo.client.TransactionRequest) -> bool:
        """Commit ``transaction`` and report whether it succeeded.

        Each commit result is paired with the operation queued at the same
        position. ``RolledBackError`` results are skipped; the first other
        exception is logged together with its operation.

        Returns:
            ``False`` right after logging the first such exception, ``True``
            when none is found.
        """
        results = transaction.commit()
        for result, operation in zip(results, transaction.operations):
            if not isinstance(result, Exception):
                continue
            if isinstance(result, kazoo.exceptions.RolledBackError):
                # Rolled back because of another failure; not the cause.
                continue
            logger.error(
                f"{result.__class__.__name__} in Kazoo transaction: {operation}"
            )
            return False
        return True
Exemplo n.º 4
0
class KazooTreeCacheTests(KazooTestCase):

    def setUp(self):
        """Create the event and error queues and start with no cache.

        Both queues come from the client handler's ``queue_impl``; ``path``
        and ``cache`` stay ``None`` until a cache is made.
        """
        super(KazooTreeCacheTests, self).setUp()
        self.cache = None
        self.path = None
        new_queue = self.client.handler.queue_impl
        self._event_queue = new_queue()
        self._error_queue = new_queue()

    def tearDown(self):
        """Report collected cache faults, close the cache, tear down the base.

        An error collected by the fault listener is re-raised so the test
        fails, except ``FakeException``, which is swallowed here.
        """
        try:
            # Inspect the fault queue first so that only errors collected
            # while the test body ran are reported, not teardown noise.
            if not self._error_queue.empty():
                try:
                    raise self._error_queue.get()
                except FakeException:
                    pass
        finally:
            # The cache started by make_cache() used to be left open; close
            # it before the client it was built on is torn down, and do so
            # even when a collected fault is being re-raised.
            if self.cache is not None:
                self.cache.close()
                self.cache = None
            super(KazooTreeCacheTests, self).tearDown()

    def make_cache(self):
        """Return this test's TreeCache, creating and starting it on first use.

        The cache is rooted at a fresh random path; its events go to
        ``_event_queue`` and its faults to ``_error_queue``.
        """
        if self.cache is not None:
            return self.cache

        def _record_event(event):
            return self._event_queue.put(event)

        def _record_fault(error):
            return self._error_queue.put(error)

        self.path = '/' + uuid.uuid4().hex
        self.cache = TreeCache(self.client, self.path)
        self.cache.listen(_record_event)
        self.cache.listen_fault(_record_fault)
        self.cache.start()
        return self.cache

    def wait_cache(self, expect=None, since=None, timeout=10):
        """Pull events from the event queue until the awaited one arrives.

        Args:
            expect: Event type the returned event must have; checked with
                ``eq_`` when not ``None``.
            since: When given, events are discarded up to and including the
                first one of this type before anything is returned.
            timeout: Passed to every ``get`` on the event queue.

        Returns:
            The first event taken after ``since`` was seen, or ``None`` when
            ``since`` was seen and ``expect`` is ``None``.
        """
        seen_since = since is None
        while True:
            event = self._event_queue.get(timeout=timeout)

            if not seen_since:
                if event.event_type != since:
                    continue
                seen_since = True
                if expect is None:
                    return None
                continue

            if expect is not None:
                eq_(event.event_type, expect)
            return event

    def spy_client(self, method_name):
        """Return a patcher that wraps a client method so calls are recorded.

        The patch replaces ``self.client.<method_name>`` with a mock created
        with ``wraps=`` the original method.
        """
        return patch.object(self.client, method_name,
                            wraps=getattr(self.client, method_name))

    def test_start(self):
        # after initialization the root znode exists at version 0 and both
        # the cache and its root node report their started/live states
        cache = self.make_cache()
        self.wait_cache(since=TreeEvent.INITIALIZED)

        root_stat = self.client.exists(self.path)
        eq_(root_stat.version, 0)

        eq_(cache._state, TreeCache.STATE_STARTED)
        eq_(cache._root._state, TreeNode.STATE_LIVE)

    @raises(KazooException)
    def test_start_started(self):
        # make_cache() has already started the cache, so this second start
        # is the call expected to raise
        cache = self.make_cache()
        cache.start()

    def test_start_closed(self):
        # make_cache() already starts the cache. The extra start() that used
        # to follow it raised KazooException on its own (the cache was
        # already started), so close() and the restart never ran and the
        # closed state was never tested. Close first, then require that the
        # restart itself is the call that raises.
        self.make_cache()
        self.cache.close()
        self.assertRaises(KazooException, self.cache.start)

    def test_close(self):
        # closing empties the tree and leaves no unpublished events behind
        cache = self.make_cache()
        self.wait_cache(since=TreeEvent.INITIALIZED)
        self.client.create(self.path + '/foo/bar/baz', makepath=True)
        # one NODE_ADDED per created level
        for _ in ('foo', 'bar', 'baz'):
            self.wait_cache(TreeEvent.NODE_ADDED)

        cache.close()

        # nothing should be published since tree closed
        ok_(self._event_queue.empty())

        # tree should be empty
        root = cache._root
        eq_(root._children, {})
        eq_(root._data, None)
        eq_(cache._state, TreeCache.STATE_CLOSED)

        # node state should not be changed
        assert_not_equal(root._state, TreeNode.STATE_DEAD)

    def test_children_operation(self):
        # create, update and delete of one child each publish one event
        # carrying the child's path, data and stat version
        self.make_cache()
        self.wait_cache(since=TreeEvent.INITIALIZED)
        child_path = self.path + '/test_children'

        def check(event, event_type, data, version):
            eq_(event.event_type, event_type)
            eq_(event.event_data.path, child_path)
            eq_(event.event_data.data, data)
            eq_(event.event_data.stat.version, version)

        self.client.create(child_path, b'test_children_1')
        check(self.wait_cache(TreeEvent.NODE_ADDED),
              TreeEvent.NODE_ADDED, b'test_children_1', 0)

        self.client.set(child_path, b'test_children_2')
        check(self.wait_cache(TreeEvent.NODE_UPDATED),
              TreeEvent.NODE_UPDATED, b'test_children_2', 1)

        # the removal event still carries the last known data and version
        self.client.delete(child_path)
        check(self.wait_cache(TreeEvent.NODE_REMOVED),
              TreeEvent.NODE_REMOVED, b'test_children_2', 1)

    def test_subtree_operation(self):
        # a nested create is reported top-down, a recursive delete bottom-up
        self.make_cache()
        self.wait_cache(since=TreeEvent.INITIALIZED)

        top_down = ('/foo', '/foo/bar', '/foo/bar/baz')

        self.client.create(self.path + '/foo/bar/baz', makepath=True)
        for suffix in top_down:
            added = self.wait_cache(TreeEvent.NODE_ADDED)
            eq_(added.event_type, TreeEvent.NODE_ADDED)
            eq_(added.event_data.path, self.path + suffix)
            eq_(added.event_data.data, b'')
            eq_(added.event_data.stat.version, 0)

        self.client.delete(self.path + '/foo', recursive=True)
        for suffix in reversed(top_down):
            removed = self.wait_cache(TreeEvent.NODE_REMOVED)
            eq_(removed.event_type, TreeEvent.NODE_REMOVED)
            eq_(removed.event_data.path, self.path + suffix)

    def test_get_data(self):
        # data and stat of every level are served from the cache
        cache = self.make_cache()
        self.wait_cache(since=TreeEvent.INITIALIZED)
        self.client.create(self.path + '/foo/bar/baz', b'@', makepath=True)
        for _ in range(3):
            self.wait_cache(TreeEvent.NODE_ADDED)

        expected_data = (
            ('', b''),
            ('/foo', b''),
            ('/foo/bar', b''),
            ('/foo/bar/baz', b'@'),
        )

        with patch.object(cache, '_client'):  # disable any remote operation
            for suffix, data in expected_data:
                node_path = self.path + suffix
                eq_(cache.get_data(node_path).data, data)
                eq_(cache.get_data(node_path).stat.version, 0)

    def test_get_children(self):
        # child names of every level are served from the cache
        cache = self.make_cache()
        self.wait_cache(since=TreeEvent.INITIALIZED)
        self.client.create(self.path + '/foo/bar/baz', b'@', makepath=True)
        for _ in range(3):
            self.wait_cache(TreeEvent.NODE_ADDED)

        expected_children = (
            ('/foo/bar/baz', frozenset()),
            ('/foo/bar', frozenset(['baz'])),
            ('/foo', frozenset(['bar'])),
            ('', frozenset(['foo'])),
        )

        with patch.object(cache, '_client'):  # disable any remote operation
            for suffix, children in expected_children:
                eq_(cache.get_children(self.path + suffix), children)

    @raises(ValueError)
    def test_get_data_out_of_tree(self):
        # a path outside the cached root is expected to raise ValueError
        cache = self.make_cache()
        self.wait_cache(since=TreeEvent.INITIALIZED)
        cache.get_data('/out_of_tree')

    @raises(ValueError)
    def test_get_children_out_of_tree(self):
        # a path outside the cached root is expected to raise ValueError
        cache = self.make_cache()
        self.wait_cache(since=TreeEvent.INITIALIZED)
        cache.get_children('/out_of_tree')

    def test_get_data_no_node(self):
        # a path under the root that has no node yields None
        cache = self.make_cache()
        self.wait_cache(since=TreeEvent.INITIALIZED)
        missing_path = self.path + '/non_exists'

        with patch.object(cache, '_client'):  # disable any remote operation
            eq_(cache.get_data(missing_path), None)

    def test_get_children_no_node(self):
        # a path under the root that has no node yields None
        cache = self.make_cache()
        self.wait_cache(since=TreeEvent.INITIALIZED)
        missing_path = self.path + '/non_exists'

        with patch.object(cache, '_client'):  # disable any remote operation
            eq_(cache.get_children(missing_path), None)

    def test_session_reconnected(self):
        # after the session is suspended and restored, data and children of
        # every cached node must have been fetched again
        self.make_cache()
        self.wait_cache(since=TreeEvent.INITIALIZED)

        foo_path = self.path + '/foo'
        self.client.create(foo_path)
        event = self.wait_cache(TreeEvent.NODE_ADDED)
        eq_(event.event_data.path, foo_path)

        with self.spy_client('get_async') as get_data, \
                self.spy_client('get_children_async') as get_children:
            # session suspended
            self.lose_connection(self.client.handler.event_object)
            self.wait_cache(TreeEvent.CONNECTION_SUSPENDED)

            # A series of refresh operations happens here, but no NODE_ADDED
            # event is raised because the zxid of each node stays the same
            # across the reconnect.

            # connection restored
            self.wait_cache(TreeEvent.CONNECTION_RECONNECTED)

            # wait for outstanding operations
            while self.cache._outstanding_ops > 0:
                self.client.handler.sleep_func(0.1)

            # inspect in-memory nodes
            root_node = self.cache._root
            foo_node = root_node._children['foo']

            # make sure that all nodes are refreshed; the same two calls are
            # expected on both spied client methods
            expected_calls = [
                call(self.path, watch=root_node._process_watch),
                call(foo_path, watch=foo_node._process_watch),
            ]
            get_data.assert_has_calls(expected_calls, any_order=True)
            get_children.assert_has_calls(expected_calls, any_order=True)

    def test_root_recreated(self):
        # deleting and re-creating the root publishes a removal followed by
        # an addition, both for the root path
        self.make_cache()
        self.wait_cache(since=TreeEvent.INITIALIZED)

        def check_root_event(event, event_type):
            eq_(event.event_type, event_type)
            eq_(event.event_data.data, b'')
            eq_(event.event_data.path, self.path)
            eq_(event.event_data.stat.version, 0)

        # remove root node
        self.client.delete(self.path)
        check_root_event(self.wait_cache(TreeEvent.NODE_REMOVED),
                         TreeEvent.NODE_REMOVED)

        # re-create root node
        self.client.ensure_path(self.path)
        check_root_event(self.wait_cache(TreeEvent.NODE_ADDED),
                         TreeEvent.NODE_ADDED)

        # the outstanding-operation counter must not have gone negative
        self.assertTrue(
            self.cache._outstanding_ops >= 0,
            'unexpected outstanding ops %r' % self.cache._outstanding_ops)

    def test_exception_handler(self):
        # an exception raised by TreeNode.on_deleted while closing must be
        # handed to the registered fault listener exactly once
        injected_error = FakeException()
        fault_listener = Mock()

        with patch.object(TreeNode, 'on_deleted') as on_deleted:
            on_deleted.side_effect = [injected_error]

            cache = self.make_cache()
            cache.listen_fault(fault_listener)

            cache.close()
            fault_listener.assert_called_once_with(injected_error)

    def test_exception_suppressed(self):
        # client calls issued after the connection is closed must neither
        # trigger on_created nor leave outstanding operations behind
        self.make_cache()
        self.wait_cache(since=TreeEvent.INITIALIZED)

        # provoke ConnectionClosedError
        self.client.stop()
        self.client.close()
        self.client.handler.start()  # keep the async completion
        self.wait_cache(since=TreeEvent.CONNECTION_LOST)

        with patch.object(TreeNode, 'on_created') as on_created:
            for client_method in ('exists', 'get', 'get_children'):
                self.cache._root._call_client(client_method, '/')

            self.wait_cache(since=TreeEvent.INITIALIZED)
            on_created.assert_not_called()
            eq_(self.cache._outstanding_ops, 0)