예제 #1
0
    def oget(
            self,
            key: str,
            default: Union[T, Type[Exception]] = None,
            expect: ExpectFunc[U] = None
    ) -> Union[Optional[T], T, U, Optional[U]]:
        """Return the object stored at *key*.

        With caching enabled the local cache is consulted first. Otherwise the raw value is
        fetched with :meth:`get` and decoded as JSON, using :meth:`decode` as the object hook.
        The decoded object is stored in the cache when caching is enabled. A stored value that
        does not start with ``{``, or that is not valid JSON, raises :exc:`ResponseError`.

        If *key* does not exist, *default* is returned. If *default* is an :exc:`Exception`
        subclass, it is instantiated with *key* and raised instead. When *expect* is given, it
        is applied to a non-``None`` result to narrow its type.
        """
        # pylint: disable=function-redefined,missing-docstring; overload
        obj = self._cache.get(key) if self.caching else None
        if obj is None:
            raw = cast(Optional[bytes], self.get(key))
            if raw is not None:
                # only JSON objects are accepted under an object key
                if not raw.startswith(b'{'):
                    raise ResponseError()
                try:
                    # loads() really yields Union[T, Dict[str, object]]; since T may itself be
                    # a dict, the union cannot be narrowed here
                    obj = cast(T, json.loads(raw.decode(), object_hook=self.decode))
                except ValueError as e:
                    raise ResponseError() from e
                if self.caching:
                    self._cache[key] = obj
        if obj is None:
            if isinstance(default, type) and issubclass(default, Exception):  # type: ignore
                raise cast(Exception, default(key))
            obj = cast(Optional[T], default)
        if expect and obj is not None:
            return expect(obj)
        return obj
예제 #2
0
    def rename(self, src, dst):
        """
        Rename key ``src`` to ``dst``.

        Cluster impl:
            The two keys may live on different cluster nodes, so the rename is
            emulated client side: DUMP src, PTTL src, DEL dst, RESTORE dst,
            DEL src. Each step is a separate call, so the operation is not
            atomic.

        Raises ``ResponseError`` when ``src == dst`` or when ``src`` has no
        value. Returns ``True`` on success.
        """
        if src == dst:
            raise ResponseError("source and destination objects are the same")

        payload = self.dump(src)
        if payload is None:
            raise ResponseError("no such key")

        # a missing or sub-millisecond TTL is passed to RESTORE as 0
        remaining_ms = self.pttl(src)
        if remaining_ms is None or remaining_ms < 1:
            remaining_ms = 0

        self.delete(dst)
        self.restore(dst, remaining_ms, payload)
        self.delete(src)
        return True
예제 #3
0
    def ask_redirect_effect(connection, command_name, **options):
        """Answer the first call with an ASK redirect to 127.0.0.1:7001.

        Before raising, it replaces ``m.side_effect`` (``m`` comes from the
        enclosing scope, presumably the patched mock) so the follow-up call must
        reach 127.0.0.1:7001 and then returns ``"MOCK_OK"``.
        """
        def answer_on_target(connection, command_name, **options):
            assert connection.host == "127.0.0.1"
            assert connection.port == 7001
            return "MOCK_OK"

        m.side_effect = answer_on_target
        redirect = ResponseError()
        redirect.message = "ASK 1337 127.0.0.1:7001"
        raise redirect
예제 #4
0
    def moved_redirect_effect(connection, command_name, **options):
        """Answer the first call with a MOVED redirect to 127.0.0.1:7002.

        Before raising, it replaces ``m.side_effect`` (``m`` comes from the
        enclosing scope, presumably the patched mock) so the follow-up call must
        reach 127.0.0.1:7002 and then returns ``"MOCK_OK"``.
        """
        def answer_on_new_owner(connection, command_name, **options):
            assert connection.host == "127.0.0.1"
            assert connection.port == 7002
            return "MOCK_OK"

        m.side_effect = answer_on_new_owner
        moved = ResponseError()
        moved.message = "MOVED 12182 127.0.0.1:7002"
        raise moved
예제 #5
0
            def side_effect(self, *args, **kwargs):
                """Fail once with CLUSTERDOWN, then accept calls on port 7007 only.

                ``parse_response_mock`` is taken from the enclosing scope; its
                ``side_effect`` is swapped before the error is raised.
                """
                def reply_from_7007(self, *args, **kwargs):
                    assert self.port == 7007
                    return "OK"
                parse_response_mock.side_effect = reply_from_7007

                text = 'CLUSTERDOWN The cluster is down. Use CLUSTER INFO for more information'
                cluster_down = ResponseError()
                cluster_down.args = (text,)
                cluster_down.message = text
                raise cluster_down
def test_ask_exception_handling(r):
    """
    ``handle_cluster_command_exception`` must turn an ASK error into an
    ``ask`` redirect aimed at the node named in the error message.
    """
    ask_error = ResponseError()
    ask_error.message = "ASK 1337 127.0.0.1:7000"
    expected = {"name": "127.0.0.1:7000", "method": "ask"}
    assert r.handle_cluster_command_exception(ask_error) == expected
예제 #7
0
def test_ask_exception_handling(r):
    """
    An ASK error must be translated by ``handle_cluster_command_exception``
    into an ``ask`` instruction for node 127.0.0.1:7000.
    """
    error = ResponseError()
    error.message = "ASK 1337 127.0.0.1:7000"
    outcome = r.handle_cluster_command_exception(error)
    assert outcome == {
        "name": "127.0.0.1:7000",
        "method": "ask",
    }
    def moved_redirect_effect(connection, command_name, **options):
        """First call: raise a MOVED redirect to 127.0.0.1:7002.

        ``m`` is taken from the enclosing scope; its ``side_effect`` is swapped
        so the retried call must land on 127.0.0.1:7002 and yields "MOCK_OK".
        """
        def retried_call(connection, command_name, **options):
            assert connection.host == "127.0.0.1"
            assert connection.port == 7002
            return "MOCK_OK"

        m.side_effect = retried_call
        error = ResponseError()
        error.message = "MOVED 12182 127.0.0.1:7002"
        raise error
    def ask_redirect_effect(connection, command_name, **options):
        """First call: raise an ASK redirect to 127.0.0.1:7001.

        ``m`` is taken from the enclosing scope; its ``side_effect`` is swapped
        so the retried call must land on 127.0.0.1:7001 and yields "MOCK_OK".
        """
        def retried_call(connection, command_name, **options):
            assert connection.host == "127.0.0.1"
            assert connection.port == 7001
            return "MOCK_OK"

        m.side_effect = retried_call
        error = ResponseError()
        error.message = "ASK 1337 127.0.0.1:7001"
        raise error
예제 #10
0
def parse_error(self, response):
    """Build (but do not raise) the exception matching an error reply.

    The first space-separated word of ``response`` is looked up in
    ``self.EXCEPTION_CLASSES``. Its value is either an exception class, which is
    called with the rest of the message, or a dict. For a dict, the first key (in
    iteration order) contained in the message selects the class, with
    ``ResponseError`` used if none matches. Codes missing from the table yield
    ``ResponseError`` with the whole reply.
    """
    code = response.split(' ')[0]
    if code not in self.EXCEPTION_CLASSES:
        return ResponseError(response)

    message = response[len(code) + 1:]
    mapping = self.EXCEPTION_CLASSES[code]
    if not isinstance(mapping, dict):
        return mapping(message)

    for reason, cls in mapping.items():
        if reason in message:
            return cls(message)
    return ResponseError(message)
            def side_effect(self, *args, **kwargs):
                """Raise CLUSTERDOWN once; afterwards only port 7007 may answer.

                Uses ``parse_response_mock`` from the enclosing scope and swaps
                its ``side_effect`` before raising.
                """
                def answer_from_7007(self, *args, **kwargs):
                    assert self.port == 7007
                    return "OK"

                parse_response_mock.side_effect = answer_from_7007

                down_text = 'CLUSTERDOWN The cluster is down. Use CLUSTER INFO for more information'
                down = ResponseError()
                down.args = (down_text,)
                down.message = down_text
                raise down
예제 #12
0
def test_moved_exception_handling(r):
    """
    A MOVED error passed to ``handle_cluster_command_exception`` must flag the
    slot table for refresh and point slot 1337 at 127.0.0.1:7000 as master.
    """
    moved_error = ResponseError()
    moved_error.message = "MOVED 1337 127.0.0.1:7000"
    r.handle_cluster_command_exception(moved_error)

    assert r.refresh_table_asap is True
    expected_node = {
        "host": "127.0.0.1",
        "port": 7000,
        "name": "127.0.0.1:7000",
        "server_type": "master",
    }
    assert r.connection_pool.nodes.slots[1337] == expected_node
def test_moved_exception_handling(r):
    """
    ``handle_cluster_command_exception`` must react to MOVED by scheduling a
    slot-table refresh and recording the new master for the moved slot.
    """
    error = ResponseError()
    error.message = "MOVED 1337 127.0.0.1:7000"
    r.handle_cluster_command_exception(error)

    slot_owner = r.connection_pool.nodes.slots[1337]
    assert r.refresh_table_asap is True
    assert slot_owner == dict(
        host="127.0.0.1",
        port=7000,
        name="127.0.0.1:7000",
        server_type="master",
    )
예제 #14
0
 def parse_error(self, response):
     """Return (not raise) the exception matching an error reply.

     The text before the first space is the error code. Known codes map through
     ``self.EXCEPTION_CLASSES`` and are called with the remaining text; unknown
     codes produce ``ResponseError`` carrying the full reply.
     """
     code, _, message = response.partition(' ')
     if code in self.EXCEPTION_CLASSES:
         return self.EXCEPTION_CLASSES[code](message)
     return ResponseError(response)
예제 #15
0
    def _execute_transaction(self, commands):
        """Run ``commands`` as one MULTI/EXEC transaction and return the results.

        ``commands`` is a list of ``(args, options)`` pairs whose first and last
        entries are MULTI and EXEC. All commands are sent in a single write. The
        replies to MULTI and to the queued commands are discarded. The EXEC reply
        is post-processed with ``self.response_callbacks``; exception entries are
        passed through untouched.

        Raises ``WatchError`` when EXEC returns None, and ``ResponseError`` when
        EXEC returns a different number of items than were queued. If reading the
        replies fails part-way, or the item count is wrong, the connection is
        disconnected before it goes back to the pool. Otherwise leftover replies
        would be read by the next command using it.
        """
        conn = self.connection_pool.get_connection('MULTI', self.shard_hint)
        try:
            all_cmds = ''.join(
                starmap(conn.pack_command,
                        [args for args, options in commands]))
            conn.send_packed_command(all_cmds)
            # we don't care about the multi/exec any longer
            commands = commands[1:-1]
            try:
                # parse off the response for MULTI and all commands prior to
                # EXEC; only the EXEC reply (the last one) carries data
                for _ in range(len(commands) + 1):
                    self.parse_response(conn, '_')
                # parse the EXEC
                response = self.parse_response(conn, '_')
            except Exception:
                # replies after the failing one are still unread on the
                # socket; never return an out-of-sync connection to the pool
                conn.disconnect()
                raise

            if response is None:
                raise WatchError("Watched variable changed.")

            if len(response) != len(commands):
                # reply stream no longer matches what was sent
                conn.disconnect()
                raise ResponseError("Wrong number of response items from "
                                    "pipeline execution")
            # We have to run response callbacks manually
            data = []
            for r, cmd in izip(response, commands):
                if not isinstance(r, Exception):
                    args, options = cmd
                    command_name = args[0]
                    if command_name in self.response_callbacks:
                        r = self.response_callbacks[command_name](r, **options)
                data.append(r)
            return data
        finally:
            self.connection_pool.release(conn)
예제 #16
0
        def perform_execute_pipeline(pipe):
            """Return two ASK errors on the first call, then execute for real.

            Every call is recorded in ``test._calls``; ``test`` and ``r`` come
            from the enclosing scope.
            """
            if test._calls:
                result = pipe.execute(raise_on_error=False)
                test._calls.append({'result': result})
                return result

            ask_error = ResponseError("ASK {0} 127.0.0.1:7003".format(r.keyslot('foo')))
            test._calls.append({'exception': ask_error})
            return [ask_error, ask_error]
예제 #17
0
    def read_response(self, command_name, catch_errors):
        """Read and decode one reply from the server.

        ``command_name`` is only passed on to nested reads. When
        ``catch_errors`` is true, exceptions raised while reading the elements
        of a multi-bulk reply are stored in the returned list instead of
        propagating, so the whole reply is always consumed.

        Returns None for null replies. Raises ``ConnectionError`` (after
        disconnecting) when nothing was read, and ``ResponseError`` for error
        replies; a ``LOADING`` error also disconnects first. An unrecognised
        type byte falls through and returns None.
        """
        response = self.read()[:-2]  # strip last two characters (\r\n)
        if not response:
            self.disconnect()
            raise ConnectionError("Socket closed on remote end")

        # server returned a null value
        if response in ('$-1', '*-1'):
            return None
        byte, response = response[0], response[1:]

        # server returned an error
        if byte == '-':
            if response.startswith('ERR '):
                response = response[4:]
            if response.startswith('LOADING '):
                # If we're loading the dataset into memory, kill the socket
                # so we re-initialize (and re-SELECT) next time.
                self.disconnect()
                response = response[8:]
            raise ResponseError(response)
        # single value
        elif byte == '+':
            return response
        # int value
        elif byte == ':':
            return long(response)
        # bulk response
        elif byte == '$':
            length = int(response)
            if length == -1:
                return None
            response = length and self.read(length) or ''
            self.read(2)  # read the \r\n delimiter
            return response
        # multi-bulk response
        elif byte == '*':
            length = int(response)
            if length == -1:
                return None
            if not catch_errors:
                return [
                    self.read_response(command_name, catch_errors)
                    for i in range(length)
                ]
            else:
                # for pipelines, we need to read everything,
                # including response errors. otherwise we'd
                # completely mess up the receive buffer
                data = []
                for i in range(length):
                    try:
                        data.append(
                            self.read_response(command_name, catch_errors))
                    # ``as`` form works on Python 2.6+ and 3; the comma
                    # form is a syntax error on Python 3
                    except Exception as e:
                        data.append(e)
                return data
예제 #18
0
    def rename(self, src, dst, replace=False):
        """
        Rename key ``src`` to ``dst``.

        Cluster impl:
            When both keys hash to the same slot, a plain RENAME is sent to
            that node so the server performs the rename itself.

            Otherwise the rename is emulated client side: DUMP src, PTTL src,
            DEL dst, RESTORE dst, DEL src. The value round-trips through the
            client and the operation is no longer atomic. ``replace`` is
            forwarded to RESTORE.

        Raises ``ResponseError`` when ``src == dst``, and on the cross-slot
        path when ``src`` has no value. The cross-slot path returns ``True``.
        """
        if src == dst:
            raise ResponseError("source and destination objects are the same")

        # same slot -> the server can do it natively
        keyslot = self.connection_pool.nodes.keyslot
        if keyslot(src) == keyslot(dst):
            return self.execute_command('RENAME', src, dst)

        # cross slot -> replay what the server would do, from the client
        payload = self.dump(src)
        if payload is None:
            raise ResponseError("no such key")

        remaining_ms = self.pttl(src)
        if remaining_ms is None or remaining_ms < 1:
            remaining_ms = 0

        self.delete(dst)
        self.restore(dst, remaining_ms, payload, replace)
        self.delete(src)
        return True
예제 #19
0
 def parse_error(self, response):
     """Return (not raise) the exception matching an error reply.

     The first space-separated word of ``response`` is the error code and is
     looked up in ``self.EXCEPTION_CLASSES``. The value is either an exception
     class or a dict mapping reason text to exception classes.

     For a dict, an exact match on the message is tried first. Failing that, the
     first reason (in iteration order) contained in the message is used, and
     ``ResponseError`` is the fallback. Matching only exact messages missed
     replies that carry extra detail after the reason.

     Codes not in the table yield ``ResponseError`` with the full reply.
     """
     error_code = response.split(' ')[0]
     if error_code in self.EXCEPTION_CLASSES:
         response = response[len(error_code) + 1:]
         exception_class = self.EXCEPTION_CLASSES[error_code]
         if isinstance(exception_class, dict):
             reasons = exception_class
             if response in reasons:
                 exception_class = reasons[response]
             else:
                 for reason, inner_exception_class in reasons.items():
                     if reason in response:
                         exception_class = inner_exception_class
                         break
                 else:
                     exception_class = ResponseError
         return exception_class(response)
     return ResponseError(response)
예제 #20
0
    def read_response(self, command_name, catch_errors):
        """Read and decode one reply from the server.

        Error replies raise ``ResponseError`` (with a leading ``ERR `` removed).
        Null replies return None, ``+`` replies return the text, ``:`` replies
        return a ``long``, ``$`` replies return the payload, and ``*`` replies
        return a list of recursively read elements. With ``catch_errors`` true,
        exceptions raised while reading list elements are stored in the list
        instead of propagating, so the whole reply is consumed.

        Raises ``ConnectionError`` (after disconnecting) if nothing was read, and
        ``InvalidResponse`` (mentioning ``command_name``) for an unknown type byte.
        """
        response = self.read().strip()  # strip surrounding whitespace, including the trailing \r\n
        if not response:
            self.disconnect()
            raise ConnectionError("Socket closed on remote end")

        # server returned a null value
        if response in ('$-1', '*-1'):
            return None
        # first character is the reply type, the rest is its header/value
        byte, response = response[0], response[1:]

        # server returned an error
        if byte == '-':
            if response.startswith('ERR '):
                response = response[4:]
            raise ResponseError(response)
        # single value
        elif byte == '+':
            return response
        # int value
        elif byte == ':':
            return long(response)
        # bulk response
        elif byte == '$':
            length = int(response)
            if length == -1:
                return None
            # a zero-length (or empty) read becomes ''
            response = length and self.read(length) or ''
            self.read(2)  # read the \r\n delimiter
            return response
        # multi-bulk response
        elif byte == '*':
            length = int(response)
            if length == -1:
                return None
            if not catch_errors:
                return [
                    self.read_response(command_name, catch_errors)
                    for i in range(length)
                ]
            else:
                # for pipelines, we need to read everything,
                # including response errors. otherwise we'd
                # completely mess up the receive buffer
                data = []
                for i in range(length):
                    try:
                        data.append(
                            self.read_response(command_name, catch_errors))
                    except Exception:
                        # sys.exc_info() form works on both Python 2 and 3
                        e = sys.exc_info()[1]
                        data.append(e)
                return data

        raise InvalidResponse("Unknown response type for: %s" % command_name)
예제 #21
0
    def _execute_transaction(self, connection, commands, raise_on_error):
        """Execute ``commands`` inside MULTI/EXEC on ``connection`` (coroutine).

        MULTI, the commands and EXEC are packed and sent in one write. Every
        reply is then read, even when some commands fail to queue, so the
        socket stays in sync. Such ``ResponseError``s are annotated and inserted
        into the EXEC result at their recorded index.

        Raises:
            WatchError: EXEC returned None.
            ResponseError / ExecAbortError: EXEC aborted. The first recorded
                queuing error is raised if there is one, otherwise the
                ``ExecAbortError`` itself.
            ResponseError: the result count does not match the command count.
            Whatever ``raise_first_error`` raises, when ``raise_on_error`` is
                true.

        Otherwise the results, passed through ``self.response_callbacks``, are
        delivered via ``gen.Return``.
        """
        cmds = chain([(('MULTI', ), {})], commands, [(('EXEC', ), {})])
        all_cmds = connection.pack_commands([args for args, _ in cmds])
        yield connection.send_packed_command(all_cmds)
        errors = []
        # parse MULTI; record instead of raising so the remaining replies
        # are still read off the socket
        try:
            yield self.parse_response(connection, '_')
        except ResponseError:
            errors.append((0, sys.exc_info()[1]))

        # parse commands
        for i, command in enumerate(commands):
            try:
                yield self.parse_response(connection, '_')
            except ResponseError:
                ex = sys.exc_info()[1]
                self.annotate_exception(ex, i + 1, command[0])
                errors.append((i, ex))

        # parse EXEC
        try:
            response = yield self.parse_response(connection, '_')
        except ExecAbortError:
            # capture now: after the DISCARD yield below, sys.exc_info() may
            # no longer report this exception (notably for Python 2 generators)
            abort_error = sys.exc_info()[1]
            if self.explicit_transaction:
                yield self.immediate_execute_command('DISCARD')
            if errors:
                raise errors[0][1]
            raise abort_error

        if response is None:
            raise WatchError("Watched variable changed.")

        for i, e in errors:
            response.insert(i, e)

        if len(response) != len(commands):
            # the reply stream is out of sync: drop the connection these
            # replies were actually read from (the ``connection`` argument),
            # not ``self.connection``, which is a separate attribute
            connection.disconnect()
            raise ResponseError("Wrong number of response items from "
                                "pipeline execution")

        if raise_on_error:
            self.raise_first_error(commands, response)

        # We have to run response callbacks manually
        data = []
        for r, cmd in izip(response, commands):
            if not isinstance(r, Exception):
                args, options = cmd
                command_name = args[0]
                if command_name in self.response_callbacks:
                    r = self.response_callbacks[command_name](r, **options)
            data.append(r)
        raise gen.Return(data)
예제 #22
0
    def parse_error(self, response):
        """Parse an error response into an exception instance (returned, not raised).

        The first space-separated element is the error code, and the error
        message follows it. Known codes map through ``self.EXCEPTION_CLASSES``
        to either an exception class or a dict of reason text -> exception
        class. For a dict, an exact message match is tried first, then the
        first reason (in iteration order) contained in the message, with
        ``ResponseError`` as the fallback.

        Unknown codes return the default ``ResponseError`` with the full reply.
        """
        # the first space-separated element is the error code
        error_code = response.split(' ')[0]

        # for known error codes, return the corresponding error object
        if error_code in self.EXCEPTION_CLASSES:
            # the error message follows the error code
            response = response[len(error_code) + 1:]
            exception_class = self.EXCEPTION_CLASSES[error_code]
            if isinstance(exception_class, dict):
                reasons = exception_class
                if response in reasons:
                    exception_class = reasons[response]
                else:
                    # server messages may carry detail after the reason text,
                    # so fall back to substring matching
                    for reason, inner_exception_class in reasons.items():
                        if reason in response:
                            exception_class = inner_exception_class
                            break
                    else:
                        exception_class = ResponseError
            return exception_class(response)
        # unknown error code: return the default
        return ResponseError(response)
예제 #23
0
파일: client.py 프로젝트: tcpavel/tweeql
    def _parse_response(self, command_name, catch_errors):
        conn = self.connection
        response = conn.read().strip()
        if not response:
            self.connection.disconnect()
            raise ConnectionError("Socket closed on remote end")

        # server returned a null value
        if response in ('$-1', '*-1'):
            return None
        byte, response = response[0], response[1:]

        # server returned an error
        if byte == '-':
            if response.startswith('ERR '):
                response = response[4:]
            raise ResponseError(response)
        # single value
        elif byte == '+':
            return response
        # int value
        elif byte == ':':
            return int(response)
        # bulk response
        elif byte == '$':
            length = int(response)
            if length == -1:
                return None
            response = length and conn.read(length) or ''
            conn.read(2)  # read the \r\n delimiter
            return response
        # multi-bulk response
        elif byte == '*':
            length = int(response)
            if length == -1:
                return None
            if not catch_errors:
                return [
                    self._parse_response(command_name, catch_errors)
                    for i in range(length)
                ]
            else:
                # for pipelines, we need to read everything, including response errors.
                # otherwise we'd completely mess up the receive buffer
                data = []
                for i in range(length):
                    try:
                        data.append(
                            self._parse_response(command_name, catch_errors))
                    except Exception, e:
                        data.append(e)
                return data
예제 #24
0
    def _execute_transaction(self, connection, commands):
        """Execute ``commands`` inside MULTI/EXEC on ``connection``.

        Commands whose options contain ``EMPTY_RESPONSE`` are not sent, and the
        option's value stands in for their result. All other replies are read
        even when some commands fail to queue, so the socket stays in sync.
        Such ``ResponseError``s are annotated and inserted into the EXEC result
        at their recorded index.

        Returns the raw EXEC results; response callbacks are not applied here.

        Raises:
            Exception from a failed queuing step, or the ``ExecAbortError``
                itself: EXEC aborted. The first recorded *exception* is raised;
                EMPTY_RESPONSE placeholders are skipped, since they are plain
                values that cannot be raised.
            WatchError: EXEC returned None.
            ResponseError: the number of results does not match the number of
                commands.
        """
        cmds = chain([(('MULTI', ), {})], commands, [(('EXEC', ), {})])
        all_cmds = connection.pack_commands(
            [args for args, options in cmds if EMPTY_RESPONSE not in options])
        connection.send_packed_command(all_cmds)
        errors = []
        # parse off the response for MULTI
        # NOTE: we need to handle ResponseErrors here and continue
        # so that we read all the additional command messages from
        # the socket
        try:
            self.parse_response(connection, '_')
        except ResponseError:
            errors.append((0, sys.exc_info()[1]))

        # and all the other commands
        for i, command in enumerate(commands):
            if EMPTY_RESPONSE in command[1]:
                # never sent, so there is no reply to read; the configured
                # placeholder becomes this command's result
                errors.append((i, command[1][EMPTY_RESPONSE]))
            else:
                try:
                    self.parse_response(connection, '_')
                except ResponseError:
                    ex = sys.exc_info()[1]
                    self.annotate_exception(ex, i + 1, command[0])
                    errors.append((i, ex))

        # parse the EXEC.
        try:
            response = self.parse_response(connection, '_')
        except ExecAbortError:
            # capture before running DISCARD so the abort error is not lost
            abort_error = sys.exc_info()[1]
            if self.explicit_transaction:
                self.immediate_execute_command('DISCARD')
            for _, error in errors:
                if isinstance(error, Exception):
                    raise error
            raise abort_error

        if response is None:
            raise WatchError("Watched variable changed.")

        # put any parse errors into the response
        for i, e in errors:
            response.insert(i, e)

        if len(response) != len(commands):
            # out of sync: drop the connection the replies were read from
            # (the ``connection`` argument, not ``self.connection``)
            connection.disconnect()
            raise ResponseError(
                "Wrong number of response items from pipeline execution")
        # raw results; callbacks are applied by the caller
        return response
예제 #25
0
    def test_error_handler_unknown_message(self):
        """An error the handler cannot interpret must propagate as the same object."""
        script = parse_script(name='foo', content="")
        original = ResponseError("ERR Unknown error")

        with self.assertRaises(ResponseError) as caught:
            with error_handler(script=script):
                raise original

        self.assertIs(original, caught.exception)
예제 #26
0
    def test_error_handler(self):
        name = 'foo'
        content = """
local a = 1;
local b = 2;
local c = 3;
local d = 4;
local e = 5;
local f = 6;
local g = 7;
local h = 8;
local i = 9;
local j = 10;
local k = 11;
local l = 12;
"""

        script = parse_script(
            name=name,
            content=content,
        )
        exception = ResponseError(
            "ERR something is wrong: f_1234abc:11: my lua error",
        )

        with self.assertRaises(ScriptError) as error:
            with error_handler(script=script):
                raise exception

        self.assertEqual(
            script,
            error.exception.script,
        )
        self.assertEqual(
            11,
            error.exception.line,
        )
        self.assertEqual(
            'my lua error',
            error.exception.lua_error,
        )
        self.assertEqual(
            'something is wrong',
            error.exception.message,
        )
예제 #27
0
    def read_response(self):
        """Read one reply and return it, decoded.

        Error replies are *returned* as ``ResponseError`` instances rather than
        raised. Null bulk / multi-bulk replies (length -1) return None, ``:``
        replies return a ``long``, and ``*`` replies return a list of nested
        replies. A ``bytes`` result is decoded with ``self.encoding`` when one
        is set.

        Raises ``ConnectionError`` if nothing was read or the server reports it
        is loading its dataset, and ``InvalidResponse`` for an unknown type byte.
        """
        response = self.read()
        if not response:
            raise ConnectionError("Socket closed on remote end")

        # first byte is the reply type; byte_to_chr normalises it to a str
        byte, response = byte_to_chr(response[0]), response[1:]

        if byte not in ('-', '+', ':', '$', '*'):
            raise InvalidResponse("Protocol Error")

        # server returned an error
        if byte == '-':
            if nativestr(response).startswith('LOADING '):
                # if we're loading the dataset into memory, kill the socket
                # so we re-initialize (and re-SELECT) next time.
                # NOTE(review): nothing is disconnected here; presumably the
                # caller drops the connection on ConnectionError — confirm
                raise ConnectionError("Redis is loading data into memory")
            # if the error starts with ERR, trim that off
            if nativestr(response).startswith('ERR '):
                response = response[4:]
            # *return*, not raise the exception class. if it is meant to be
            # raised, it will be at a higher level.
            return ResponseError(response)
        # single value
        elif byte == '+':
            pass
        # int value
        elif byte == ':':
            response = long(response)
        # bulk response
        elif byte == '$':
            length = int(response)
            if length == -1:
                return None
            # NOTE(review): assumes read(length) also consumes the trailing
            # \r\n delimiter — confirm against this class's read()
            response = self.read(length)
        # multi-bulk response
        elif byte == '*':
            length = int(response)
            if length == -1:
                return None
            response = [self.read_response() for i in xrange(length)]
        # list elements were already decoded by the nested calls
        if isinstance(response, bytes) and self.encoding:
            response = response.decode(self.encoding)
        return response
예제 #28
0
    def _parse_response(self, command_name):
        conn = self.connection
        response = conn.read().strip()
        if not response:
            self.connection.disconnect()
            raise ConnectionError("Socket closed on remote end")

        # server returned a null value
        if response in ('$-1', '*-1'):
            return None
        byte, response = response[0], response[1:]

        # server returned an error
        if byte == '-':
            if response.startswith('ERR '):
                response = response[4:]
            raise ResponseError(response)
        # single value
        elif byte == '+':
            return response
        # int value
        elif byte == ':':
            return int(response)
        # bulk response
        elif byte == '$':
            length = int(response)
            if length == -1:
                return None
            response = length and conn.read(length) or ''
            conn.read(2)  # read the \r\n delimiter
            return response
        # multi-bulk response
        elif byte == '*':
            length = int(response)
            if length == -1:
                return None
            return [self._parse_response(command_name) for i in range(length)]

        raise InvalidResponse("Unknown response type for: %s" % command_name)
예제 #29
0
파일: client.py 프로젝트: tcpavel/tweeql
 def _execute(self, commands):
     # build up all commands into a single request to increase network perf
     all_cmds = ''.join([c for _1, c, _2 in commands])
     self.connection.send(all_cmds, self)
     # we only care about the last item in the response, which should be
     # the EXEC command
     for i in range(len(commands) - 1):
         _ = self.parse_response('_')
     # tell the response parse to catch errors and return them as
     # part of the response
     response = self.parse_response('_', catch_errors=True)
     # don't return the results of the MULTI or EXEC command
     commands = [(c[0], c[2]) for c in commands[1:-1]]
     if len(response) != len(commands):
         raise ResponseError("Wrong number of response items from "
                             "pipline execution")
     # Run any callbacks for the commands run in the pipeline
     data = []
     for r, cmd in zip(response, commands):
         if not isinstance(r, Exception):
             if cmd[0] in self.RESPONSE_CALLBACKS:
                 r = self.RESPONSE_CALLBACKS[cmd[0]](r, **cmd[1])
         data.append(r)
     return data
예제 #30
0
 def execute_command_via_connection(r, *argv, **kwargs):
     """Test wrapper: fail the first call with an ASK redirect, then delegate.

     The first call (while ``test.execute_command_calls`` is empty) records and
     raises ``ResponseError("ASK 1 127.0.0.1:7003")``. Later calls forward to
     ``execute_command_via_connection_original``. Each such call is recorded in
     ``test.execute_command_calls`` with its argv/kwargs and either the result
     or the raised exception.
     """
     # the first time this is called, simulate an ASK exception.
     # after that behave normally.
     # capture all the requests and responses.
     if not test.execute_command_calls:
         e = ResponseError("ASK 1 127.0.0.1:7003")
         test.execute_command_calls.append({'exception': e})
         raise e
     try:
         result = execute_command_via_connection_original(
             r, *argv, **kwargs)
     except Exception as e:
         test.execute_command_calls.append({
             'argv': argv,
             'kwargs': kwargs,
             'exception': e
         })
         # bare ``raise`` re-raises the active exception with its original
         # traceback; ``raise e`` discards it on Python 2
         raise
     test.execute_command_calls.append({
         'argv': argv,
         'kwargs': kwargs,
         'result': result
     })
     return result
예제 #31
0
    def read_response(self):
        """Read one reply from the server.

        Every error reply is *returned* as a ``ResponseError`` (with a leading
        ``ERR `` removed), not raised. Previously the ``return`` sat inside the
        ``ERR `` branch, so any other error code (e.g. WRONGTYPE, NOSCRIPT,
        MOVED) fell through to ``InvalidResponse("Protocol Error")``.

        Null bulk / multi-bulk replies (length -1) return None, ``:`` returns a
        ``long``, ``*`` returns a list of nested replies.

        Raises ``ConnectionError`` if nothing was read or the server is loading
        its dataset, and ``InvalidResponse`` for an unknown type byte.
        """
        response = self.read()
        if not response:
            raise ConnectionError("Socket closed on remote end")

        byte, response = response[0], response[1:]

        # server returned an error
        if byte == '-':
            if response.startswith('LOADING '):
                # If we're loading the dataset into memory, kill the socket
                # so we re-initialize (and re-SELECT) next time.
                raise ConnectionError("Redis is loading data into memory")
            # trim the generic ERR prefix; other error codes are kept as-is
            if response.startswith('ERR '):
                response = response[4:]
            return ResponseError(response)
        # single value
        elif byte == '+':
            return response
        # int value
        elif byte == ':':
            return long(response)
        # bulk response
        elif byte == '$':
            length = int(response)
            if length == -1:
                return None
            response = self.read(length)
            return response
        # multi-bulk response
        elif byte == '*':
            length = int(response)
            if length == -1:
                return None
            return [self.read_response() for i in xrange(length)]
        raise InvalidResponse("Protocol Error")
예제 #32
0
 def bgsave(self):
     """Record a BGSAVE call; raise ``ResponseError`` when configured to fail."""
     self._called.append('BGSAVE')
     if not self.bgsave_raises_ResponseError:
         return None
     raise ResponseError()
예제 #33
0
 def psetex(self, name, time_ms, value):
     """Set ``name`` to ``value`` with a millisecond expiry (fake PSETEX).

     ``time_ms`` is an int of milliseconds or a ``timedelta``. Zero *and*
     negative expiries raise ``ResponseError``, as the real server rejects any
     non-positive PSETEX expiry; previously only 0 was rejected. Otherwise
     the call is delegated to ``self.set(name, value, px=time_ms)``.
     """
     if isinstance(time_ms, timedelta):
         time_ms = int(timedelta_total_seconds(time_ms) * 1000)
     if time_ms <= 0:
         raise ResponseError("invalid expire time in SETEX")
     return self.set(name, value, px=time_ms)
예제 #34
0
 def expireat(self, key, timestamp):
     """Record an absolute expiry for ``key`` in ``self.expiry`` (fake EXPIREAT).

     Only ``int`` timestamps are accepted; anything else raises ``ResponseError``.
     """
     if isinstance(timestamp, int):
         self.expiry[key] = timestamp
         return
     raise ResponseError("value is not an integer or out of range")
예제 #35
0
 def mock_register(cls, redis):
     """Stub that always raises a bare ``ResponseError``.

     Both arguments are ignored. Presumably used to simulate a failed
     registration against Redis — confirm with its caller.
     """
     raise ResponseError()