Example #1
async def test_initialize_follow_cluster():
    n = NodeManager(nodemanager_follow_cluster=True,
                    startup_nodes=[{
                        'host': '127.0.0.1',
                        'port': 7000
                    }])
    n.orig_startup_nodes = None
    await n.initialize()
Example #2
def test_reset():
    """
    Test that reset method resets variables back to correct default values.
    """
    n = NodeManager(startup_nodes=[{}])
    n.initialize = Mock()
    n.reset()
    assert n.initialize.call_count == 0
Example #3
async def test_init_slots_cache_slots_collision():
    """
    Test that if 2 nodes do not agree on the same slots setup it should raise an error.
    In this test both nodes will say that the first slots block should be bound to different
     servers.
    """

    n = NodeManager(startup_nodes=[
        {
            "host": "127.0.0.1",
            "port": 7000
        },
        {
            "host": "127.0.0.1",
            "port": 7001
        },
    ])

    def monkey_link(host=None, port=None, *args, **kwargs):
        """
        Helper function to return custom slots cache data from different redis nodes
        """
        if port == 7000:
            result = [[0, 5460, [b'127.0.0.1', 7000], [b'127.0.0.1', 7003]],
                      [
                          5461, 10922, [b'127.0.0.1', 7001],
                          [b'127.0.0.1', 7004]
                      ]]

        elif port == 7001:
            result = [[0, 5460, [b'127.0.0.1', 7001], [b'127.0.0.1', 7003]],
                      [
                          5461, 10922, [b'127.0.0.1', 7000],
                          [b'127.0.0.1', 7004]
                      ]]

        else:
            result = []

        r = StrictRedisCluster(host=host, port=port, decode_responses=True)
        orig_execute_command = r.execute_command

        def execute_command(*args, **kwargs):
            if args == ("cluster", "slots"):
                return result
            elif args == ('CONFIG GET', 'cluster-require-full-coverage'):
                return {'cluster-require-full-coverage': 'yes'}
            else:
                return orig_execute_command(*args, **kwargs)

        r.execute_command = execute_command
        return r

    n.get_redis_link = monkey_link
    with pytest.raises(RedisClusterException) as ex:
        await n.initialize()
    assert str(ex.value).startswith(
        "startup_nodes could not agree on a valid slots cache."), str(ex.value)
Example #4
def test_empty_startup_nodes():
    """
    It should not be possible to create a node manager with no nodes specified.
    """
    with pytest.raises(RedisClusterException):
        NodeManager()

    with pytest.raises(RedisClusterException):
        NodeManager([])
Example #5
    def __init__(self,
                 startup_nodes=None,
                 connection_class=ClusterConnection,
                 max_connections=None,
                 max_connections_per_node=False,
                 reinitialize_steps=None,
                 skip_full_coverage_check=False,
                 nodemanager_follow_cluster=False,
                 readonly=False,
                 max_idle_time=0,
                 idle_check_interval=1,
                 **connection_kwargs):
        """
        :skip_full_coverage_check:
            Skips the check of cluster-require-full-coverage config, useful for clusters
            without the CONFIG command (like aws)
        :nodemanager_follow_cluster:
            The node manager will during initialization try the last set of nodes that
            it was operating on. This will allow the client to drift along side the cluster
            if the cluster nodes move around alot.
        """
        super(ClusterConnectionPool,
              self).__init__(connection_class=connection_class,
                             max_connections=max_connections)

        # Special case to make the from_url method compliant with the cluster setting.
        # The from_url method will send in the ip and port through a different variable than the
        # regular startup_nodes variable.
        if startup_nodes is None:
            if 'port' in connection_kwargs and 'host' in connection_kwargs:
                startup_nodes = [{
                    'host': connection_kwargs.pop('host'),
                    'port': str(connection_kwargs.pop('port')),
                }]

        self.max_connections = max_connections or 2**31
        self.max_connections_per_node = max_connections_per_node
        self.nodes = NodeManager(
            startup_nodes,
            reinitialize_steps=reinitialize_steps,
            skip_full_coverage_check=skip_full_coverage_check,
            max_connections=self.max_connections,
            nodemanager_follow_cluster=nodemanager_follow_cluster,
            **connection_kwargs)
        self.initialized = False

        self.connections = {}
        self.connection_kwargs = connection_kwargs
        self.connection_kwargs['readonly'] = readonly
        self.readonly = readonly
        self.max_idle_time = max_idle_time
        self.idle_check_interval = idle_check_interval
        self.reset()

        if "stream_timeout" not in self.connection_kwargs:
            self.connection_kwargs[
                "stream_timeout"] = ClusterConnectionPool.RedisClusterDefaultTimeout
Example #6
def test_random_startup_node():
    """
    Hard to test reliable for a random
    """
    s = [{"1": 1}, {"2": 2}, {"3": 3}]
    n = NodeManager(startup_nodes=s)
    random_node = n.random_startup_node()

    for i in range(0, 5):
        assert random_node in s
Example #7
async def test_all_nodes():
    """
    Set a list of nodes and it should be possible to itterate over all
    """
    n = NodeManager(startup_nodes=[{"host": "127.0.0.1", "port": 7000}])
    await n.initialize()

    nodes = list(n.nodes.values())

    for node in n.all_nodes():
        assert node in nodes
Example #8
async def test_all_nodes_masters():
    """
    Set a list of nodes with random masters/slaves config and it shold be possible
    to itterate over all of them.
    """
    n = NodeManager(startup_nodes=[{"host": "127.0.0.1", "port": 7000}, {"host": "127.0.0.1", "port": 7001}])
    await n.initialize()

    nodes = [node for node in n.nodes.values() if node['server_type'] == 'master']

    for node in n.all_masters():
        assert node in nodes
Example #9
async def test_reset():
    """
    Test that reset method resets variables back to correct default values.
    """
    class AsyncMock(Mock):
        def __await__(self):
            future = asyncio.get_event_loop().create_future()
            future.set_result(self)
            result = yield from future
            return result

    n = NodeManager(startup_nodes=[{}])
    n.initialize = AsyncMock()
    await n.reset()
    assert n.initialize.call_count == 1
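On Python 3.8+ the hand-rolled awaitable mock above can usually be replaced by the standard library's unittest.mock.AsyncMock. A minimal sketch of the same assertion, assuming reset() awaits initialize() exactly once:

from unittest.mock import AsyncMock

async def test_reset_with_stdlib_asyncmock():
    n = NodeManager(startup_nodes=[{}])
    n.initialize = AsyncMock()  # calling it returns an awaitable automatically
    await n.reset()
    assert n.initialize.call_count == 1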
Example #10
def test_set_node():
    """
    Test to update data in a slot.
    """
    expected = {
        "host": "127.0.0.1",
        "name": "127.0.0.1:7000",
        "port": 7000,
        "server_type": "master",
    }

    n = NodeManager(startup_nodes=[{}])
    assert len(n.slots) == 0, "no slots should exist"
    res = n.set_node(host="127.0.0.1", port=7000, server_type="master")
    assert res == expected
    assert n.nodes == {expected['name']: expected}
Example #11
async def test_cluster_slots_error():
    """
    Check that exception is raised if initialize can't execute
    'CLUSTER SLOTS' command.
    """
    with patch.object(StrictRedisCluster, 'execute_command') as execute_command_mock:
        execute_command_mock.side_effect = Exception("foobar")

        n = NodeManager(startup_nodes=[{}])

        with pytest.raises(RedisClusterException):
            await n.initialize()
Example #12
async def test_init_with_down_node():
    """
    If I can't connect to one of the nodes, everything should still work.
    But if I can't connect to any of the nodes, exception should be thrown.
    """
    def get_redis_link(host, port, decode_responses=False):
        if port == 7000:
            raise ConnectionError('mock connection error for 7000')
        return StrictRedis(host=host, port=port, decode_responses=decode_responses)

    with patch.object(NodeManager, 'get_redis_link', side_effect=get_redis_link):
        n = NodeManager(startup_nodes=[{"host": "127.0.0.1", "port": 7000}])
        with pytest.raises(RedisClusterException) as e:
            await n.initialize()
        assert 'Redis Cluster cannot be connected' in str(e.value)
Example #13
async def test_cluster_one_instance():
    """
    If the cluster exists of only 1 node then there is some hacks that must
    be validated they work.
    """
    with patch.object(StrictRedis, 'execute_command') as mock_execute_command:
        return_data = [[0, 16383, ['', 7006]]]

        def patch_execute_command(*args, **kwargs):
            if args == ('CONFIG GET', 'cluster-require-full-coverage'):
                return {'cluster-require-full-coverage': 'yes'}
            else:
                return return_data

        # mock_execute_command.return_value = return_data
        mock_execute_command.side_effect = patch_execute_command

        n = NodeManager(startup_nodes=[{"host": "127.0.0.1", "port": 7006}])
        await n.initialize()

        assert n.nodes == {
            "127.0.0.1:7006": {
                'host': '127.0.0.1',
                'name': '127.0.0.1:7006',
                'port': 7006,
                'server_type': 'master',
            }
        }

        assert len(n.slots) == 16384
        for i in range(0, 16384):
            assert n.slots[i] == [{
                "host": "127.0.0.1",
                "name": "127.0.0.1:7006",
                "port": 7006,
                "server_type": "master",
            }]
Example #14
class ClusterConnectionPool(ConnectionPool):
    """
    Custom connection pool for rediscluster
    """
    RedisClusterDefaultTimeout = None

    def __init__(self,
                 startup_nodes=None,
                 connection_class=ClusterConnection,
                 max_connections=None,
                 max_connections_per_node=False,
                 reinitialize_steps=None,
                 skip_full_coverage_check=False,
                 nodemanager_follow_cluster=False,
                 readonly=False,
                 **connection_kwargs):
        """
        :skip_full_coverage_check:
            Skips the check of cluster-require-full-coverage config, useful for clusters
            without the CONFIG command (like aws)
        :nodemanager_follow_cluster:
            The node manager will during initialization try the last set of nodes that
            it was operating on. This will allow the client to drift along side the cluster
            if the cluster nodes move around alot.
        """
        super(ClusterConnectionPool,
              self).__init__(connection_class=connection_class,
                             max_connections=max_connections)

        # Special case to make the from_url method compliant with the cluster setting.
        # The from_url method will send in the ip and port through a different variable than the
        # regular startup_nodes variable.
        if startup_nodes is None:
            if 'port' in connection_kwargs and 'host' in connection_kwargs:
                startup_nodes = [{
                    'host': connection_kwargs.pop('host'),
                    'port': str(connection_kwargs.pop('port')),
                }]

        self.max_connections = max_connections or 2**31
        self.max_connections_per_node = max_connections_per_node
        self.nodes = NodeManager(
            startup_nodes,
            reinitialize_steps=reinitialize_steps,
            skip_full_coverage_check=skip_full_coverage_check,
            max_connections=self.max_connections,
            nodemanager_follow_cluster=nodemanager_follow_cluster,
            **connection_kwargs)
        self.initialized = False

        self.connections = {}
        self.connection_kwargs = connection_kwargs
        self.connection_kwargs['readonly'] = readonly
        self.readonly = readonly
        self.reset()

        if "stream_timeout" not in self.connection_kwargs:
            self.connection_kwargs[
                "stream_timeout"] = ClusterConnectionPool.RedisClusterDefaultTimeout

    def __repr__(self):
        """
        Return a string with all unique ip:port combinations that this pool is connected to.
        """
        return "{0}<{1}>".format(
            type(self).__name__, ", ".join([
                self.connection_class.description.format(**node)
                for node in self.nodes.startup_nodes
            ]))

    async def initialize(self):
        if not self.initialized:
            await self.nodes.initialize()
            self.initialized = True

    def reset(self):
        """
        Resets the connection pool back to a clean state.
        """
        self.pid = os.getpid()
        self._created_connections = 0
        self._created_connections_per_node = {}  # Dict(Node, Int)
        self._available_connections = {}  # Dict(Node, List)
        self._in_use_connections = {}  # Dict(Node, Set)
        self._check_lock = threading.Lock()
        self.initialized = False

    def _checkpid(self):
        """
        """
        if self.pid != os.getpid():
            with self._check_lock:
                if self.pid == os.getpid():
                    # another thread already did the work while we waited
                    # on the lock.
                    return
                self.disconnect()
                self.reset()

    def get_connection(self, command_name, *keys, **options):
        # Only pubsub command/connection should be allowed here
        if command_name != "pubsub":
            raise RedisClusterException(
                "Only 'pubsub' commands can use get_connection()")

        channel = options.pop('channel', None)

        if not channel:
            return self.get_random_connection()

        slot = self.nodes.keyslot(channel)
        node = self.get_master_node_by_slot(slot)

        self._checkpid()

        try:
            connection = self._available_connections.get(node["name"],
                                                         []).pop()
        except IndexError:
            connection = self.make_connection(node)

        if node['name'] not in self._in_use_connections:
            self._in_use_connections[node['name']] = set()

        self._in_use_connections[node['name']].add(connection)

        return connection

    def make_connection(self, node):
        """
        Create a new connection
        """
        if self.count_all_num_connections(node) >= self.max_connections:
            if self.max_connections_per_node:
                raise RedisClusterException(
                    "Too many connection ({0}) for node: {1}".format(
                        self.count_all_num_connections(node), node['name']))

            raise RedisClusterException("Too many connections")

        self._created_connections_per_node.setdefault(node['name'], 0)
        self._created_connections_per_node[node['name']] += 1
        connection = self.connection_class(host=node["host"],
                                           port=node["port"],
                                           **self.connection_kwargs)

        # Must store node in the connection to make it easier to track
        connection.node = node

        return connection

    def release(self, connection):
        """
        Releases the connection back to the pool
        """
        self._checkpid()
        if connection.pid != self.pid:
            return

        # Remove the current connection from _in_use_connections and add it back to
        # the available pool. The connection may not be tracked there, so remove it
        # in a safe way.
        self._in_use_connections.get(connection.node["name"], set()).discard(connection)
        self._available_connections.setdefault(connection.node["name"],
                                               []).append(connection)

    def disconnect(self):
        """
        Disconnect all connections in the pool, both idle and in use.
        """
        all_conns = chain(
            self._available_connections.values(),
            self._in_use_connections.values(),
        )

        for node_connections in all_conns:
            for connection in node_connections:
                connection.disconnect()

    def count_all_num_connections(self, node):
        """
        """
        if self.max_connections_per_node:
            return self._created_connections_per_node.get(node['name'], 0)

        return sum(self._created_connections_per_node.values())

    def get_random_connection(self):
        """
        Open new connection to random redis server.
        """
        if self._available_connections:
            # _available_connections maps node name -> list of idle connections,
            # so random.choice() cannot be used on the dict directly. Pick a random
            # node name and reuse one of its idle connections if it has any.
            node_name = random.choice(list(self._available_connections.keys()))
            if self._available_connections[node_name]:
                connection = self._available_connections[node_name].pop()
                self._in_use_connections.setdefault(node_name, set()).add(connection)
                return connection

        for node in self.nodes.random_startup_node_iter():
            connection = self.get_connection_by_node(node)

            if connection:
                return connection

        raise Exception("Can't reach a single startup node.")

    def get_connection_by_key(self, key):
        """
        """
        if not key:
            raise RedisClusterException(
                "No way to dispatch this command to Redis Cluster.")

        return self.get_connection_by_slot(self.nodes.keyslot(key))

    def get_connection_by_slot(self, slot):
        """
        Determine what server a specific slot belongs to and return a redis object that is connected
        """
        self._checkpid()

        try:
            return self.get_connection_by_node(self.get_node_by_slot(slot))
        except KeyError:
            return self.get_random_connection()

    def get_connection_by_node(self, node):
        """
        get a connection by node
        """
        self._checkpid()
        self.nodes.set_node_name(node)

        try:
            # Try to get connection from existing pool
            connection = self._available_connections.get(node["name"],
                                                         []).pop()
        except IndexError:
            connection = self.make_connection(node)

        self._in_use_connections.setdefault(node["name"],
                                            set()).add(connection)

        return connection

    def get_master_node_by_slot(self, slot):
        return self.nodes.slots[slot][0]

    def get_node_by_slot(self, slot):
        if self.readonly:
            return random.choice(self.nodes.slots[slot])
        return self.get_master_node_by_slot(slot)
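A short usage sketch of the pool's per-key checkout and release cycle (the key is a placeholder; the pool must be initialized first so the slot cache exists):

async def example_usage():
    # Hypothetical usage; 'user:1000' is just an example key.
    pool = ClusterConnectionPool(startup_nodes=[{'host': '127.0.0.1', 'port': 7000}])
    await pool.initialize()                         # builds the slot cache via NodeManager

    conn = pool.get_connection_by_key('user:1000')  # key -> slot -> node -> connection
    try:
        pass  # send commands on conn here
    finally:
        pool.release(conn)                          # back into _available_connections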
Example #15
def test_keyslot():
    """
    Test that method will compute correct key in all supported cases
    """
    n = NodeManager([{}])

    assert n.keyslot("foo") == 12182
    assert n.keyslot("{foo}bar") == 12182
    assert n.keyslot("{foo}") == 12182
    assert n.keyslot(1337) == 4314

    assert n.keyslot(125) == n.keyslot(b"125")
    assert n.keyslot(125) == n.keyslot("\x31\x32\x35")
    assert n.keyslot("大奖") == n.keyslot(b"\xe5\xa4\xa7\xe5\xa5\x96")
    assert n.keyslot(u"大奖") == n.keyslot(b"\xe5\xa4\xa7\xe5\xa5\x96")
    assert n.keyslot(1337.1234) == n.keyslot("1337.1234")
    assert n.keyslot(1337) == n.keyslot("1337")
    assert n.keyslot(b"abc") == n.keyslot("abc")
    assert n.keyslot("abc") == n.keyslot(str("abc"))
    assert n.keyslot(str("abc")) == n.keyslot(b"abc")
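The slot values asserted above follow the standard Redis Cluster mapping: hash only the substring inside the first non-empty {...} hash tag if one exists, then take CRC16 (XModem variant) modulo 16384. A minimal sketch, assuming Python's binascii.crc_hqx matches the CRC16 variant Redis uses:

import binascii

def keyslot_sketch(key):
    # Normalize to bytes the same way the test mixes str/int/float/bytes inputs.
    if not isinstance(key, bytes):
        key = str(key).encode('utf-8')
    # Hash tag rule: only the part between the first '{' and the following '}'
    # counts, and only if that part is non-empty.
    start = key.find(b'{')
    if start != -1:
        end = key.find(b'}', start + 1)
        if end > start + 1:
            key = key[start + 1:end]
    return binascii.crc_hqx(key, 0) % 16384

# keyslot_sketch('foo') == keyslot_sketch('{foo}bar') == 12182, matching the test above.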
Example #16
def test_wrong_startup_nodes_type():
    """
    If something other then a list type itteratable is provided it should fail
    """
    with pytest.raises(TypeError):
        NodeManager({})