Beispiel #1
0
 def cmd_ping(self, command, args, now, connection):
     """Handle ``PING [message]``.

     With no argument, reply with the simple string ``PONG``. With one
     argument, echo it back. The echo is a bulk string because it is
     client-supplied, arbitrary bytes; a simple-string reply cannot carry
     CR/LF.

     Raises:
         WrongNumberOfArgumentsError: if more than one argument is given.
     """
     if len(args) == 0:
         # bytes, like every other simple-string reply in this server
         return RedisString(b'PONG')
     if len(args) == 1:
         return RedisBulkString(args[0])
     raise WrongNumberOfArgumentsError(command)
Beispiel #2
0
    def cmd_select(self, command, args, now, connection):
        """Handle ``SELECT index``: switch the connection's current database.

        The selected index is stored on ``connection.db_index`` and ``OK`` is
        returned.

        Raises:
            WrongNumberOfArgumentsError: unless exactly one argument is given.
            RedisError: if the index is negative or not below
                ``len(self.storage.dbs)`` (presumably ``dbs`` is indexed by
                position -- NOTE(review): confirm).
        """
        if len(args) != 1:
            raise WrongNumberOfArgumentsError(command)

        db_index = ensure_int(args[0])
        # Valid indexes are 0 .. len - 1, so the upper bound is exclusive.
        if not 0 <= db_index < len(self.storage.dbs):
            raise RedisError('invalid DB index')

        connection.db_index = db_index
        return RedisString(b'OK')
Beispiel #3
0
    def cmd_mget(self, command, args, now, connection):
        """Handle ``MGET key [key ...]``.

        Returns an array with one entry per requested key, in request order.
        Each entry is the value as a bulk string, or a null bulk string when
        ``get_active_key`` returns ``None`` for that key.

        Raises:
            WrongNumberOfArgumentsError: if no key is given.
        """
        if len(args) == 0:
            raise WrongNumberOfArgumentsError(command)

        db = self._get_db(connection)

        values = [db.get_active_key(key, now) for key in args]
        # Bulk strings, matching cmd_get: stored values are arbitrary bytes and
        # a simple-string reply cannot contain CR/LF.
        return RedisArray([
            RedisBulkString(value) if value is not None else RedisNullBulkString()
            for value in values
        ])
Beispiel #4
0
    def cmd_get(self, command, args, now, connection):
        """Handle ``GET key``.

        Replies with the key's active value as a bulk string, or with a null
        bulk string when ``get_active_key`` yields ``None``.

        Raises:
            WrongNumberOfArgumentsError: unless exactly one argument is given.
        """
        if len(args) == 1:
            db = self._get_db(connection)
            stored = db.get_active_key(args[0], now)
            return (RedisNullBulkString() if stored is None
                    else RedisBulkString(stored))
        raise WrongNumberOfArgumentsError(command)
Beispiel #5
0
    def cmd_mset(self, command, args, now, connection):
        """Handle ``MSET key value [key value ...]``.

        Stores every key/value pair in order and removes any expiry previously
        registered for each key. Always replies ``OK``.

        Raises:
            WrongNumberOfArgumentsError: if fewer than two arguments are given
                or the argument count is odd.
        """
        if len(args) < 2 or len(args) % 2 != 0:
            raise WrongNumberOfArgumentsError(command)

        db = self._get_db(connection)

        # Even positions are keys, odd positions are their values.
        for name, payload in zip(args[0::2], args[1::2]):
            logger.info('Setting key: db=%s, key=%s, value=%s', db.index,
                        name, payload)
            db.kv[name] = payload
            db.remove_key_from_expiration(name)

        return RedisString(b'OK')
Beispiel #6
0
    def cmd_pubsub(self, command, args, now, connection):
        """Handle the ``PUBSUB`` introspection command.

        Subcommands (matched case-insensitively):

        * ``CHANNELS [pattern]`` -- array of channels that have at least one
          subscribed connection, optionally filtered by a pattern that must
          match the whole channel name.
        * ``NUMSUB [channel ...]`` -- flat array alternating each queried
          channel with its number of subscribed connections.
        * ``NUMPAT`` -- number of distinct patterns subscribed to (via
          PSUBSCRIBE) across all connections.

        Raises:
            WrongNumberOfArgumentsError: if no subcommand is given, or CHANNELS
                gets more than one pattern.
            RedisError: for an unknown subcommand.
        """
        if len(args) < 1:
            raise WrongNumberOfArgumentsError(command)

        subcommand = args[0].lower()

        if subcommand == b'channels':
            if len(args) > 2:
                raise WrongNumberOfArgumentsError(command)
            # No pattern means every channel qualifies. With a pattern, the
            # whole channel name must match (fullmatch, as cmd_keys does),
            # not merely a prefix.
            pattern_re = (build_re_from_pattern(args[1].decode())
                          if len(args) == 2 else None)
            return RedisArray([
                RedisBulkString(channel)
                for channel, connections in self._channel_subscriptions.items()
                if len(connections) > 0 and (
                    pattern_re is None
                    or pattern_re.fullmatch(channel.decode()))
            ])

        if subcommand == b'numsub':
            result = []
            for channel in args[1:]:
                # .get() so that querying an unknown channel does not insert
                # an empty entry (cmd_subscribe's `[channel].add(...)` suggests
                # the mapping is a defaultdict -- NOTE(review): confirm).
                subscribers = self._channel_subscriptions.get(channel, ())
                result.extend([
                    RedisBulkString(channel),
                    RedisInteger(len(subscribers)),
                ])
            return RedisArray(result)

        if subcommand == b'numpat':
            # The map is connection -> set of patterns; count the distinct
            # patterns in its values, not its connection keys.
            unique_patterns = set().union(
                *self._connection_pattern_channel_map.values())
            return RedisInteger(len(unique_patterns))

        raise RedisError(
            'Unknown PUBSUB subcommand or wrong number of arguments '
            'for \'%s\'' % subcommand.decode(), )
Beispiel #7
0
    def cmd_del(self, command, args, now, connection):
        """Handle ``DEL key [key ...]``.

        Deletes each key that is present in the database and still active,
        clears its expiry entry, and replies with how many keys were removed.
        Keys that are absent or not active are skipped.

        Raises:
            WrongNumberOfArgumentsError: if no key is given.
        """
        if len(args) == 0:
            raise WrongNumberOfArgumentsError(command)

        db = self._get_db(connection)
        removed = 0
        for name in args:
            if name not in db.kv:
                continue
            if not db.is_key_active(name, now):
                continue
            logger.info('Removing key: db=%s, key=%s', db.index, name)
            del db.kv[name]
            db.remove_key_from_expiration(name)
            removed += 1

        return RedisInteger(removed)
Beispiel #8
0
    def cmd_publish(self, command, args, now, connection):
        """Handle ``PUBLISH channel message``.

        Sends ``[message, channel, payload]`` to every connection subscribed to
        the channel. Sends ``[pmessage, pattern, channel, payload]`` once per
        (pattern, connection) pair whose pattern matches the whole channel
        name. Writes are scheduled on ``self._write_pool``.

        Returns:
            RedisInteger with the number of deliveries scheduled: direct
            subscribers plus matching pattern subscriptions.

        Raises:
            WrongNumberOfArgumentsError: unless exactly two arguments are given.
        """
        if len(args) != 2:
            raise WrongNumberOfArgumentsError(command)

        channel = args[0]
        message = args[1]
        channel_name = channel.decode()

        # All connections subscribing to the channel. .get() avoids inserting
        # an empty entry when nobody listens (the mapping appears to be a
        # defaultdict given cmd_subscribe -- NOTE(review): confirm).
        channel_connections = set(self._channel_subscriptions.get(channel, ()))
        # All connections subscribing to a pattern matching the channel. The
        # same connection may appear once per matching pattern.
        pattern_connection_tuples = [
            (pattern, c)
            for c, patterns in self._connection_pattern_channel_map.items()
            for pattern in patterns
            if build_re_from_pattern(pattern).fullmatch(channel_name)
        ]

        if channel_connections:
            data = RedisArray([
                RedisBulkString(b'message'),
                RedisBulkString(channel),
                RedisBulkString(message)
            ])
            logger.info(
                'Triggering sending message to subscribed connections: '
                'channel=%s, data=%s, num_connections=%s, connections=%s',
                channel, data, len(channel_connections), channel_connections)

            for c in channel_connections:
                # TODO: This might result in out-of-order messages. Do it with a queue and
                # connection-specific job instead?
                self._write_pool.spawn(c.write, data=data)

        for pattern, c in pattern_connection_tuples:
            # pmessage carries the matched pattern AND the originating channel.
            data = RedisArray([
                RedisBulkString(b'pmessage'),
                RedisBulkString(pattern),
                RedisBulkString(channel),
                RedisBulkString(message)
            ])
            logger.info(
                'Triggering sending message to subscribed connections matching pattern: '
                'channel=%s, pattern=%s, data=%s, connection=%s', channel,
                pattern, data, c)

            self._write_pool.spawn(c.write, data=data)

        count = len(channel_connections) + len(pattern_connection_tuples)
        return RedisInteger(count)
Beispiel #9
0
    def cmd_keys(self, command, args, now, connection):
        """Handle ``KEYS pattern``.

        Replies with an array of every key in the connection's database whose
        decoded name fully matches the pattern and that is still active.

        Raises:
            WrongNumberOfArgumentsError: unless exactly one argument is given.
        """
        if len(args) != 1:
            raise WrongNumberOfArgumentsError(command)

        db = self._get_db(connection)

        pattern_re = build_re_from_pattern(args[0])
        logger.info('Will filter keys by pattern: re=%s', pattern_re.pattern)
        matching = [
            key for key in db.kv.keys()
            if pattern_re.fullmatch(key.decode()) and db.is_key_active(key, now)
        ]

        # Keys are arbitrary bytes, so reply with bulk strings (as cmd_get and
        # PUBSUB CHANNELS do); a simple string cannot carry CR/LF.
        return RedisArray([RedisBulkString(key) for key in matching])
Beispiel #10
0
    def cmd_psubscribe(self, command, args, now, connection):
        """Handle ``PSUBSCRIBE pattern [pattern ...]``.

        Adds each pattern to the connection's pattern set. Replies with one
        ``[psubscribe, pattern, count]`` array per pattern, where ``count`` is
        ``_current_connection_subscriptions(connection)`` right after that
        pattern was added. Puts the connection into the ``'pubsub'`` state,
        as cmd_subscribe does.

        Raises:
            WrongNumberOfArgumentsError: if no pattern is given.
        """
        if len(args) == 0:
            raise WrongNumberOfArgumentsError(command)

        self._ensure_connection(connection)

        responses = []
        for pattern in args:
            self._connection_pattern_channel_map[connection].add(pattern)

            responses.append(
                RedisArray([
                    RedisBulkString(b'psubscribe'),
                    RedisBulkString(pattern),
                    RedisInteger(
                        self._current_connection_subscriptions(connection))
                ]))

        # A pattern subscription puts the connection in pub/sub mode exactly
        # like a channel subscription does.
        connection.state = 'pubsub'

        return RedisMultipleResponses(responses)
Beispiel #11
0
    def cmd_subscribe(self, command, args: List[bytes], now, connection):
        """Handle ``SUBSCRIBE channel [channel ...]``.

        Registers the connection on each channel in both directions (channel
        -> connections and connection -> channels). Replies with one
        ``[subscribe, channel, count]`` array per channel, where ``count`` is
        ``_current_connection_subscriptions(connection)`` right after that
        channel was added. Afterwards the connection's state is ``'pubsub'``.

        Raises:
            WrongNumberOfArgumentsError: if no channel is given.
        """
        if len(args) == 0:
            raise WrongNumberOfArgumentsError(command)

        self._ensure_connection(connection)

        confirmations = []
        for name in args:
            self._channel_subscriptions[name].add(connection)
            self._connection_channel_map[connection].add(name)
            # The count must be taken after this channel's registration.
            active_total = self._current_connection_subscriptions(connection)
            confirmations.append(RedisArray([
                RedisBulkString(b'subscribe'),
                RedisBulkString(name),
                RedisInteger(active_total),
            ]))

        connection.state = 'pubsub'
        return RedisMultipleResponses(confirmations)
Beispiel #12
0
def test_wrong_number_of_arguments_error():
    """WrongNumberOfArgumentsError serializes to a RESP error line naming the command."""
    error = WrongNumberOfArgumentsError(b'dummy')
    expected = b'-ERR wrong number of arguments for \'dummy\' command\r\n'
    assert error.to_resp() == expected
Beispiel #13
0
    def cmd_set(self, command, args, now, connection):
        """Handle ``SET key value [NX|XX] [EX seconds|PX milliseconds]``.

        Options are matched case-insensitively, as Redis clients commonly send
        them upper-case.

        * NX: only set if the key has no active value.
        * XX: only set if the key has an active value.
        * EX/PX: set an expiry. EX is converted to ``seconds * 1000`` before
          being passed to ``db.set_expiry_for_key``. Without either, any
          existing expiry is removed.

        Returns:
            ``OK`` on success, or a null bulk string when the NX/XX condition
            prevents the write.

        Raises:
            WrongNumberOfArgumentsError: if fewer than two arguments are given.
            CommandSyntaxError: unknown option, EX/PX missing its value, NX
                combined with XX, or EX combined with PX.
            RedisError: EX/PX value that is not positive.
        """
        if len(args) < 2:
            raise WrongNumberOfArgumentsError(command)

        db = self._get_db(connection)

        key = args[0]
        value = args[1]

        nx = False
        xx = False
        ex = None
        px = None

        i = 2
        while i < len(args):
            option = args[i].lower()
            if option == b'nx':
                nx = True
            elif option == b'xx':
                xx = True
            elif option in (b'ex', b'px'):
                # The amount lives in the *next* argument, which must exist.
                if i + 1 >= len(args):
                    raise CommandSyntaxError()
                amount = ensure_int(args[i + 1])
                if option == b'ex':
                    ex = amount
                else:
                    px = amount
                i += 1
            else:
                # Unsupported options are rejected rather than silently ignored.
                raise CommandSyntaxError()
            i += 1

        if xx and nx:
            raise CommandSyntaxError()

        # Compare against None: a value of 0 must not read as "not given".
        if ex is not None and px is not None:
            raise CommandSyntaxError()

        if (ex is not None and ex <= 0) or (px is not None and px <= 0):
            raise RedisError("invalid expire time in 'set' command")

        key_exists = db.get_active_key(key, now) is not None
        if key_exists and nx:
            logger.info(
                'Not setting key, exists and nx is specified: db=%s, key=%s',
                db.index, key)
            return RedisNullBulkString()
        if not key_exists and xx:
            logger.info(
                'Not setting key, doesnt exist and xx is specified: db=%s, key=%s',
                db.index,
                key,
            )
            return RedisNullBulkString()

        logger.info('Setting key: db=%s, key=%s, value=%s', db.index, key,
                    value)
        db.kv[key] = value

        if px is not None:
            db.set_expiry_for_key(key, px, now)
        elif ex is not None:
            db.set_expiry_for_key(key, ex * 1000, now)
        else:
            db.remove_key_from_expiration(key)

        return RedisString(b'OK')