Beispiel #1
0
class ClusterConnectionPool(ConnectionPool):
    """
    Custom connection pool for rediscluster.

    Connections are tracked per node name: idle connections are kept in
    ``_available_connections`` (node name -> list) and connections that have
    been handed out are kept in ``_in_use_connections`` (node name -> set).
    """
    # Default "stream_timeout" connection kwarg, used when the caller does not pass one.
    RedisClusterDefaultTimeout = None

    def __init__(self,
                 startup_nodes=None,
                 connection_class=ClusterConnection,
                 max_connections=None,
                 max_connections_per_node=False,
                 reinitialize_steps=None,
                 skip_full_coverage_check=False,
                 nodemanager_follow_cluster=False,
                 readonly=False,
                 **connection_kwargs):
        """
        :startup_nodes:
            List of dicts with 'host' and 'port' keys. When omitted, a single startup
            node is built from the 'host' and 'port' connection kwargs if both are present.
        :connection_class:
            Class instantiated by make_connection() for every new connection.
        :max_connections:
            Upper bound on created connections. Falls back to 2**31 when falsy.
        :max_connections_per_node:
            When truthy, max_connections is enforced per node instead of across all nodes.
        :reinitialize_steps:
            Forwarded unchanged to NodeManager.
        :skip_full_coverage_check:
            Skips the check of cluster-require-full-coverage config, useful for clusters
            without the CONFIG command (like aws)
        :nodemanager_follow_cluster:
            The node manager will during initialization try the last set of nodes that
            it was operating on. This will allow the client to drift along side the cluster
            if the cluster nodes move around a lot.
        :readonly:
            When truthy, get_node_by_slot() picks a random node serving the slot instead
            of always the first one. The flag is also passed to every connection.
        :connection_kwargs:
            Extra keyword arguments forwarded to NodeManager and to every connection.
        """
        super(ClusterConnectionPool,
              self).__init__(connection_class=connection_class,
                             max_connections=max_connections)

        # Special case to make from_url method compliant with cluster setting.
        # from_url method will send in the ip and port through a different variable than the
        # regular startup_nodes variable.
        if startup_nodes is None:
            if 'port' in connection_kwargs and 'host' in connection_kwargs:
                startup_nodes = [{
                    'host': connection_kwargs.pop('host'),
                    'port': str(connection_kwargs.pop('port')),
                }]

        self.max_connections = max_connections or 2**31
        self.max_connections_per_node = max_connections_per_node
        # NOTE: the node manager is built before 'readonly' and 'stream_timeout' are
        # added to the kwargs below, so it only sees the caller supplied kwargs.
        self.nodes = NodeManager(
            startup_nodes,
            reinitialize_steps=reinitialize_steps,
            skip_full_coverage_check=skip_full_coverage_check,
            max_connections=self.max_connections,
            nodemanager_follow_cluster=nodemanager_follow_cluster,
            **connection_kwargs)
        self.initialized = False

        self.connections = {}
        self.connection_kwargs = connection_kwargs
        self.connection_kwargs['readonly'] = readonly
        self.readonly = readonly
        self.reset()

        if "stream_timeout" not in self.connection_kwargs:
            self.connection_kwargs[
                "stream_timeout"] = ClusterConnectionPool.RedisClusterDefaultTimeout

    def __repr__(self):
        """
        Return a string with all unique ip:port combinations that this pool is connected to.
        """
        return "{0}<{1}>".format(
            type(self).__name__, ", ".join([
                self.connection_class.description.format(**node)
                for node in self.nodes.startup_nodes
            ]))

    async def initialize(self):
        """
        Initialize the node manager once; later calls are no-ops until reset() is called.
        """
        if not self.initialized:
            await self.nodes.initialize()
            self.initialized = True

    def reset(self):
        """
        Resets the connection pool back to a clean state.
        """
        self.pid = os.getpid()
        self._created_connections = 0
        self._created_connections_per_node = {}  # Dict(Node, Int)
        self._available_connections = {}  # Dict(Node, List)
        self._in_use_connections = {}  # Dict(Node, Set)
        self._check_lock = threading.Lock()
        self.initialized = False

    def _checkpid(self):
        """
        Disconnect and reset the pool when the current process id differs from the
        one recorded by reset().
        """
        if self.pid != os.getpid():
            with self._check_lock:
                if self.pid == os.getpid():
                    # another thread already did the work while we waited
                    # on the lock.
                    return
                self.disconnect()
                self.reset()

    def get_connection(self, command_name, *keys, **options):
        """
        Return a connection for a pubsub command.

        If a truthy 'channel' option is given, the connection goes to the first node
        serving the slot of that channel; otherwise a random connection is returned.

        Raises RedisClusterException when command_name is not "pubsub".
        """
        # Only pubsub command/connection should be allowed here
        if command_name != "pubsub":
            raise RedisClusterException(
                "Only 'pubsub' commands can use get_connection()")

        channel = options.pop('channel', None)

        if not channel:
            return self.get_random_connection()

        slot = self.nodes.keyslot(channel)
        node = self.get_master_node_by_slot(slot)

        self._checkpid()

        try:
            # Reuse an idle connection for this node when one exists
            connection = self._available_connections.get(node["name"],
                                                         []).pop()
        except IndexError:
            connection = self.make_connection(node)

        self._in_use_connections.setdefault(node['name'],
                                            set()).add(connection)

        return connection

    def make_connection(self, node):
        """
        Create a new connection to the given node.

        Raises RedisClusterException when the connection limit (per node or global,
        depending on max_connections_per_node) has been reached.
        """
        if self.count_all_num_connections(node) >= self.max_connections:
            if self.max_connections_per_node:
                raise RedisClusterException(
                    "Too many connection ({0}) for node: {1}".format(
                        self.count_all_num_connections(node), node['name']))

            raise RedisClusterException("Too many connections")

        self._created_connections_per_node.setdefault(node['name'], 0)
        self._created_connections_per_node[node['name']] += 1
        connection = self.connection_class(host=node["host"],
                                           port=node["port"],
                                           **self.connection_kwargs)

        # Must store node in the connection to make it easier to track
        connection.node = node

        return connection

    def release(self, connection):
        """
        Releases the connection back to the pool.

        Connections whose pid does not match the pool's pid are ignored.
        """
        self._checkpid()
        if connection.pid != self.pid:
            return

        # Remove the current connection from _in_use_connections and add it back to the
        # available pool. The connection may not be registered as in use, so the removal
        # must tolerate a missing node entry or a missing connection.
        node_name = connection.node["name"]
        self._in_use_connections.get(node_name, set()).discard(connection)
        self._available_connections.setdefault(node_name,
                                               []).append(connection)

    def disconnect(self):
        """
        Disconnect every connection tracked by the pool, both idle and in use.
        """
        all_conns = chain(
            self._available_connections.values(),
            self._in_use_connections.values(),
        )

        for node_connections in all_conns:
            for connection in node_connections:
                connection.disconnect()

    def count_all_num_connections(self, node):
        """
        Return the number of connections created for the given node when
        max_connections_per_node is set, otherwise the number created for all nodes.
        """
        if self.max_connections_per_node:
            return self._created_connections_per_node.get(node['name'], 0)

        return sum(self._created_connections_per_node.values())

    def get_random_connection(self):
        """
        Return a connection to a random node.

        An idle connection from a randomly chosen node is reused when one exists;
        otherwise a connection is requested for the startup nodes in random order.

        Raises Exception when no connection could be obtained.
        """
        self._checkpid()

        # _available_connections maps node name -> list of idle connections, so the
        # random pick has to be made among node names that actually hold a connection.
        candidates = [
            name for name, idle in self._available_connections.items() if idle
        ]
        if candidates:
            node_name = random.choice(candidates)
            connection = self._available_connections[node_name].pop()
            self._in_use_connections.setdefault(node_name,
                                                set()).add(connection)
            return connection

        for node in self.nodes.random_startup_node_iter():
            connection = self.get_connection_by_node(node)

            if connection:
                return connection

        raise Exception("Cant reach a single startup node.")

    def get_connection_by_key(self, key):
        """
        Return a connection for the slot that the given key hashes to.

        Raises RedisClusterException when key is falsy.
        """
        if not key:
            raise RedisClusterException(
                "No way to dispatch this command to Redis Cluster.")

        return self.get_connection_by_slot(self.nodes.keyslot(key))

    def get_connection_by_slot(self, slot):
        """
        Determine what server a specific slot belongs to and return a connection to it.

        Falls back to a random connection when the lookup raises KeyError.
        """
        self._checkpid()

        try:
            return self.get_connection_by_node(self.get_node_by_slot(slot))
        except KeyError:
            return self.get_random_connection()

    def get_connection_by_node(self, node):
        """
        Return a connection to the given node, reusing an idle one when available,
        and register it as in use.
        """
        self._checkpid()
        self.nodes.set_node_name(node)

        try:
            # Try to get connection from existing pool
            connection = self._available_connections.get(node["name"],
                                                         []).pop()
        except IndexError:
            connection = self.make_connection(node)

        self._in_use_connections.setdefault(node["name"],
                                            set()).add(connection)

        return connection

    def get_master_node_by_slot(self, slot):
        """
        Return the first node listed for the given slot.
        """
        return self.nodes.slots[slot][0]

    def get_node_by_slot(self, slot):
        """
        Return a node for the given slot: a random one of the nodes listed for the
        slot when the pool is readonly, otherwise the first one.
        """
        if self.readonly:
            return random.choice(self.nodes.slots[slot])
        return self.get_master_node_by_slot(slot)
def test_keyslot():
    """
    Verify that keyslot() yields the expected slot for every supported key type.
    """
    n = NodeManager([{}])

    # Keys with a known slot number. The brace-wrapped variants are expected to
    # land on the same slot as the plain "foo" key.
    known_slots = (
        ("foo", 12182),
        ("{foo}bar", 12182),
        ("{foo}", 12182),
        (1337, 4314),
    )
    for key, expected_slot in known_slots:
        assert n.keyslot(key) == expected_slot

    # Pairs of keys of different types that must map to the same slot.
    equivalent_keys = (
        (125, b"125"),
        (125, "\x31\x32\x35"),
        ("大奖", b"\xe5\xa4\xa7\xe5\xa5\x96"),
        (u"大奖", b"\xe5\xa4\xa7\xe5\xa5\x96"),
        (1337.1234, "1337.1234"),
        (1337, "1337"),
        (b"abc", "abc"),
        ("abc", str("abc")),
        (str("abc"), b"abc"),
    )
    for left, right in equivalent_keys:
        assert n.keyslot(left) == n.keyslot(right)