def oget(
        self, key: str, default: Union[T, Type[Exception]] = None,
        expect: ExpectFunc[U] = None) -> Union[Optional[T], T, U, Optional[U]]:
    """Return the object at *key*.

    If *key* does not exist, *default* is returned. If *default* is an :exc:`Exception`, it is
    raised instead. The object type can be narrowed with *expect*.
    """
    # pylint: disable=function-redefined,missing-docstring; overload
    # Consult the in-process cache first (only when caching is enabled).
    obj = self._cache.get(key) if self.caching else None
    if obj is None:
        raw = cast(Optional[bytes], self.get(key))
        if raw is not None:
            # Stored objects are JSON objects; anything else is not ours.
            if not raw.startswith(b'{'):
                raise ResponseError()
            try:
                # loads() actually returns Union[T, Dict[str, object]], but as T may be dict
                # there is no way to eliminate it here
                obj = cast(T, json.loads(raw.decode(), object_hook=self.decode))
            except ValueError as e:
                raise ResponseError() from e
            if self.caching:
                self._cache[key] = obj

    if obj is None:
        # An exception *class* as default means "raise instead of returning".
        if isinstance(default, type) and issubclass(default, Exception):  # type: ignore
            raise cast(Exception, default(key))
        obj = cast(Optional[T], default)

    if expect and obj is not None:
        return expect(obj)
    return obj
def rename(self, src, dst):
    """
    Rename key ``src`` to ``dst``

    Cluster impl: the two keys may live on different cluster nodes, so the
    value is copied through the client with DUMP, PTTL, DELETE, RESTORE and
    DELETE. Each step is a separate call, which means the operation is NOT
    atomic.
    """
    if src == dst:
        raise ResponseError("source and destination objects are the same")

    payload = self.dump(src)
    if payload is None:
        raise ResponseError("no such key")

    remaining_ms = self.pttl(src)
    # Any non-positive / missing TTL is restored as 0 ("no expiry").
    if remaining_ms is None or remaining_ms < 1:
        remaining_ms = 0

    self.delete(dst)
    self.restore(dst, remaining_ms, payload)
    self.delete(src)
    return True
def ask_redirect_effect(connection, command_name, **options):
    """Simulate an ASK redirection on the first call.

    Rebinds ``m.side_effect`` so the next call succeeds only when it is made
    against 127.0.0.1:7001, then raises a ResponseError carrying an ASK
    message that points there.
    """
    def ok_response(connection, command_name, **options):
        # The retried call must target the node named in the ASK reply.
        assert connection.host == "127.0.0.1"
        assert connection.port == 7001
        return "MOCK_OK"

    m.side_effect = ok_response

    ask_error = ResponseError()
    ask_error.message = "ASK 1337 127.0.0.1:7001"
    raise ask_error
def moved_redirect_effect(connection, command_name, **options):
    """Simulate a MOVED redirection on the first call.

    Rebinds ``m.side_effect`` so the next call succeeds only when it is made
    against 127.0.0.1:7002, then raises a ResponseError carrying a MOVED
    message that points there.
    """
    def ok_response(connection, command_name, **options):
        # The retried call must target the node named in the MOVED reply.
        assert connection.host == "127.0.0.1"
        assert connection.port == 7002
        return "MOCK_OK"

    m.side_effect = ok_response

    moved_error = ResponseError()
    moved_error.message = "MOVED 12182 127.0.0.1:7002"
    raise moved_error
def side_effect(self, *args, **kwargs):
    """First call: raise a CLUSTERDOWN error; later calls must hit port 7007.

    Rebinds ``parse_response_mock.side_effect`` to a function that asserts
    it is invoked on a connection whose port is 7007 and returns "OK".
    """
    def ok_call(self, *args, **kwargs):
        assert self.port == 7007
        return "OK"

    parse_response_mock.side_effect = ok_call

    text = 'CLUSTERDOWN The cluster is down. Use CLUSTER INFO for more information'
    resp = ResponseError()
    resp.args = (text,)
    resp.message = text
    raise resp
def test_ask_exception_handling(r):
    """
    An ASK error is mapped by `handle_cluster_command_exception` to an
    ``ask`` redirect aimed at the node named in the message.
    """
    ask_error = ResponseError()
    ask_error.message = "ASK 1337 127.0.0.1:7000"

    expected = {"name": "127.0.0.1:7000", "method": "ask"}
    assert r.handle_cluster_command_exception(ask_error) == expected
def parse_error(self, response):
    """Build (but do not raise) the exception object for an error reply.

    The first space-separated token of *response* is the error code. Unknown
    codes produce ``ResponseError(response)`` with the full text. For known
    codes the remainder of the message is passed to the mapped class. When
    the mapping is itself a dict, its keys are substrings searched for in
    the message: the first key found (in dict iteration order) selects the
    class, and ``ResponseError`` is the fallback.
    """
    error_code = response.split(' ')[0]
    if error_code not in self.EXCEPTION_CLASSES:
        return ResponseError(response)

    message = response[len(error_code) + 1:]
    mapped = self.EXCEPTION_CLASSES[error_code]
    if not isinstance(mapped, dict):
        return mapped(message)

    for reason, specific_class in mapped.items():
        if reason in message:
            return specific_class(message)
    return ResponseError(message)
def side_effect(self, *args, **kwargs):
    """Raise CLUSTERDOWN once, then route later calls to a port-7007 check.

    ``parse_response_mock.side_effect`` is swapped for a function that
    asserts the connection port is 7007 and returns "OK". The raised error
    has both ``args`` and ``message`` set to the CLUSTERDOWN text.
    """
    def ok_call(self, *args, **kwargs):
        assert self.port == 7007
        return "OK"

    parse_response_mock.side_effect = ok_call

    cluster_down = 'CLUSTERDOWN The cluster is down. Use CLUSTER INFO for more information'
    error = ResponseError()
    error.args = (cluster_down, )
    error.message = cluster_down
    raise error
def test_moved_exception_handling(r):
    """
    A MOVED error makes `handle_cluster_command_exception` flag the slot
    table for refresh and point the slot at the node from the message.
    """
    moved = ResponseError()
    moved.message = "MOVED 1337 127.0.0.1:7000"
    r.handle_cluster_command_exception(moved)

    assert r.refresh_table_asap is True

    expected_node = {
        "host": "127.0.0.1",
        "port": 7000,
        "name": "127.0.0.1:7000",
        "server_type": "master",
    }
    assert r.connection_pool.nodes.slots[1337] == expected_node
def parse_error(self, response):
    """Build (but do not raise) the exception object for an error reply.

    The text before the first space is looked up in ``EXCEPTION_CLASSES``.
    On a hit, the mapped class is instantiated with the rest of the message.
    Otherwise a ``ResponseError`` carrying the whole reply is returned.
    """
    # partition() yields the same head as split(' ')[0] and the same tail
    # as slicing past "<code> ".
    code, _sep, message = response.partition(' ')
    exception_classes = self.EXCEPTION_CLASSES
    if code in exception_classes:
        return exception_classes[code](message)
    return ResponseError(response)
def _execute_transaction(self, commands):
    """Send a MULTI ... EXEC batch over one pooled connection and return the results.

    *commands* is a list of ``(args, options)`` pairs whose first and last
    entries are the MULTI and EXEC commands themselves. Those two are dropped
    before the EXEC results are matched up. Non-exception results are run
    through ``self.response_callbacks`` keyed by command name.

    Raises WatchError when EXEC's reply is None, and ResponseError when the
    number of results differs from the number of queued commands. The
    connection is always released back to the pool.
    """
    conn = self.connection_pool.get_connection('MULTI', self.shard_hint)
    try:
        # Pack everything (MULTI and EXEC included) into a single write.
        payload = ''.join(
            starmap(conn.pack_command, [args for args, _opts in commands]))
        conn.send_packed_command(payload)

        queued = commands[1:-1]
        # Discard the replies to MULTI and to every queued command; only
        # EXEC's reply carries the data we return.
        for _ in range(len(queued) + 1):
            self.parse_response(conn, '_')

        response = self.parse_response(conn, '_')
        if response is None:
            raise WatchError("Watched variable changed.")
        if len(response) != len(queued):
            raise ResponseError("Wrong number of response items from "
                                "pipeline execution")

        # Response callbacks have to be applied by hand here.
        results = []
        for reply, (args, options) in izip(response, queued):
            if not isinstance(reply, Exception):
                name = args[0]
                if name in self.response_callbacks:
                    reply = self.response_callbacks[name](reply, **options)
            results.append(reply)
        return results
    finally:
        self.connection_pool.release(conn)
def perform_execute_pipeline(pipe):
    """Fake an ASK redirect for both commands on the first call, then run for real.

    Every call records what happened in ``test._calls``: the synthetic
    exception on the first call, and the real pipeline result afterwards.
    """
    if test._calls:
        result = pipe.execute(raise_on_error=False)
        test._calls.append({'result': result})
        return result

    redirect = ResponseError("ASK {0} 127.0.0.1:7003".format(r.keyslot('foo')))
    test._calls.append({'exception': redirect})
    return [redirect, redirect]
def read_response(self, command_name, catch_errors):
    """Read one reply from the socket and decode it.

    :param command_name: name of the command the reply belongs to (passed
        through to nested reads).
    :param catch_errors: when true, errors raised while reading elements of
        a multi-bulk reply are stored in the returned list instead of
        propagating. This keeps the receive buffer in sync for pipelines.
    :returns: ``None`` for null replies, a string for status/bulk replies,
        a ``long`` for integer replies, or a list for multi-bulk replies.
    :raises ConnectionError: if the socket returned nothing.
    :raises ResponseError: for an error reply (``ERR ``/``LOADING `` prefixes
        stripped; ``LOADING`` also disconnects the socket).

    NOTE(review): an unrecognised reply prefix falls through and returns
    None; other parser variants raise InvalidResponse here — confirm.
    """
    response = self.read()[:-2]  # strip last two characters (\r\n)
    if not response:
        self.disconnect()
        raise ConnectionError("Socket closed on remote end")

    # server returned a null value
    if response in ('$-1', '*-1'):
        return None

    byte, response = response[0], response[1:]

    # server returned an error
    if byte == '-':
        if response.startswith('ERR '):
            response = response[4:]
        if response.startswith('LOADING '):
            # If we're loading the dataset into memory, kill the socket
            # so we re-initialize (and re-SELECT) next time.
            self.disconnect()
            response = response[8:]
        raise ResponseError(response)
    # single value
    elif byte == '+':
        return response
    # int value
    elif byte == ':':
        return long(response)
    # bulk response
    elif byte == '$':
        length = int(response)
        if length == -1:
            return None
        # Same result as the old `length and read(length) or ''` idiom,
        # written explicitly.
        response = (self.read(length) or '') if length else ''
        self.read(2)  # read the \r\n delimiter
        return response
    # multi-bulk response
    elif byte == '*':
        length = int(response)
        if length == -1:
            return None
        if not catch_errors:
            return [
                self.read_response(command_name, catch_errors)
                for _ in range(length)
            ]
        # for pipelines, we need to read everything,
        # including response errors. otherwise we'd
        # completely mess up the receive buffer
        data = []
        for _ in range(length):
            try:
                data.append(self.read_response(command_name, catch_errors))
            except Exception as e:  # `except X, e` is Python-2-only syntax
                data.append(e)
        return data
def rename(self, src, dst, replace=False):
    """
    Rename key ``src`` to ``dst``

    Cluster impl: if the src and dst keys are in the same slot, a plain
    RENAME is sent to that node so the server performs the rename itself.
    If the keys are in different slots, the client-side implementation is
    used as a fallback. In that case the operation is no longer atomic,
    because the key is dumped and posted back to the server through the
    client. ``replace`` is forwarded to RESTORE on that fallback path only.
    """
    if src == dst:
        raise ResponseError("source and destination objects are the same")

    nodes = self.connection_pool.nodes
    # Same slot -> same node: let the server do an atomic RENAME.
    if nodes.keyslot(src) == nodes.keyslot(dst):
        return self.execute_command('RENAME', src, dst)

    # Cross-slot: replicate what the server does internally, from the client.
    payload = self.dump(src)
    if payload is None:
        raise ResponseError("no such key")

    remaining_ms = self.pttl(src)
    if remaining_ms is None or remaining_ms < 1:
        remaining_ms = 0

    self.delete(dst)
    self.restore(dst, remaining_ms, payload, replace)
    self.delete(src)
    return True
def parse_error(self, response):
    """Build (but do not raise) the exception object for an error reply.

    The first space-separated token is the error code. Unknown codes yield
    ``ResponseError(response)``. Known codes map to a class, or to a dict
    keyed by the *exact* remaining message; a dict miss falls back to
    ``ResponseError``. The class is instantiated with the message minus the
    code.
    """
    code, _sep, message = response.partition(' ')
    if code not in self.EXCEPTION_CLASSES:
        return ResponseError(response)

    exception_class = self.EXCEPTION_CLASSES[code]
    if isinstance(exception_class, dict):
        exception_class = exception_class.get(message, ResponseError)
    return exception_class(message)
def read_response(self, command_name, catch_errors):
    """Read one reply from the socket and decode it.

    ``catch_errors`` makes multi-bulk reads collect per-element exceptions
    into the returned list instead of raising, so a pipeline can consume the
    whole reply. Error replies raise ResponseError with a leading ``ERR ``
    removed. An unknown reply prefix raises InvalidResponse.
    """
    # strip() removes the trailing \r\n (and any other surrounding whitespace)
    line = self.read().strip()
    if not line:
        self.disconnect()
        raise ConnectionError("Socket closed on remote end")

    # null bulk / null multi-bulk
    if line in ('$-1', '*-1'):
        return None

    marker, body = line[0], line[1:]

    if marker == '-':
        raise ResponseError(body[4:] if body.startswith('ERR ') else body)

    if marker == '+':
        return body

    if marker == ':':
        return long(body)

    if marker == '$':
        size = int(body)
        if size == -1:
            return None
        payload = ''
        if size:
            payload = self.read(size) or ''
        self.read(2)  # consume the \r\n after the payload
        return payload

    if marker == '*':
        count = int(body)
        if count == -1:
            return None
        if not catch_errors:
            return [self.read_response(command_name, catch_errors)
                    for _ in range(count)]
        # Pipelines must drain every element, errors included, or the
        # receive buffer gets out of sync.
        items = []
        for _ in range(count):
            try:
                items.append(self.read_response(command_name, catch_errors))
            except Exception:
                items.append(sys.exc_info()[1])
        return items

    raise InvalidResponse("Unknown response type for: %s" % command_name)
def _execute_transaction(self, connection, commands, raise_on_error):
    """Coroutine: run *commands* inside MULTI/EXEC on *connection*.

    Every reply (MULTI, each queued command, EXEC) is read even after an
    error, so the socket stays in sync. Errors raised while queueing are
    annotated and spliced back into EXEC's result list at their position.

    :param connection: the connection the transaction is written to and
        read from.
    :param commands: list of ``(args, options)`` pairs (without MULTI/EXEC).
    :param raise_on_error: when true, the first error in the results is
        raised via ``self.raise_first_error``.
    :returns: (via ``gen.Return``) the results, with registered response
        callbacks applied to every non-exception entry.
    :raises ExecAbortError: or the first queueing error when EXEC aborts.
    :raises WatchError: when EXEC's reply is None.
    :raises ResponseError: when the result count does not match *commands*.
    """
    cmds = chain([(('MULTI', ), {})], commands, [(('EXEC', ), {})])
    all_cmds = connection.pack_commands([args for args, _ in cmds])
    yield connection.send_packed_command(all_cmds)
    errors = []

    # parse MULTI; keep going on error so the remaining replies are consumed
    try:
        yield self.parse_response(connection, '_')
    except ResponseError:
        errors.append((0, sys.exc_info()[1]))

    # parse the "QUEUED" reply of every command
    for i, command in enumerate(commands):
        try:
            yield self.parse_response(connection, '_')
        except ResponseError:
            ex = sys.exc_info()[1]
            self.annotate_exception(ex, i + 1, command[0])
            errors.append((i, ex))

    # parse EXEC
    try:
        response = yield self.parse_response(connection, '_')
    except ExecAbortError:
        # Capture the abort error before yielding again: on Python 2 the
        # pending exception info is not reliably preserved across a yield.
        exec_abort = sys.exc_info()[1]
        if self.explicit_transaction:
            yield self.immediate_execute_command('DISCARD')
        if errors:
            raise errors[0][1]
        raise exec_abort

    if response is None:
        raise WatchError("Watched variable changed.")

    for i, e in errors:
        response.insert(i, e)

    if len(response) != len(commands):
        # The stream on *this* connection is now out of sync; drop it.
        # (self.connection may be None or a different connection.)
        connection.disconnect()
        raise ResponseError("Wrong number of response items from "
                            "pipeline execution")

    if raise_on_error:
        self.raise_first_error(commands, response)

    # We have to run response callbacks manually
    data = []
    for r, cmd in izip(response, commands):
        if not isinstance(r, Exception):
            args, options = cmd
            command_name = args[0]
            if command_name in self.response_callbacks:
                r = self.response_callbacks[command_name](r, **options)
        data.append(r)
    raise gen.Return(data)
def parse_error(self, response):
    """Parse an error response into an exception object (returned, not raised).

    The first space-separated element of *response* is the error code.
    Known codes return their mapped exception, built from the message text
    that follows the code. A dict mapping is looked up by the exact message,
    with ResponseError as the fallback. Unknown codes fall back to a default
    ResponseError carrying the whole reply.
    """
    # The error code is the first space-separated element.
    code, _sep, message = response.partition(' ')

    # Error codes that are not known get the default ResponseError.
    if code not in self.EXCEPTION_CLASSES:
        return ResponseError(response)

    # Known error codes: build the corresponding exception object from the
    # message that follows the code.
    exception_class = self.EXCEPTION_CLASSES[code]
    if isinstance(exception_class, dict):
        exception_class = exception_class.get(message, ResponseError)
    return exception_class(message)
def _parse_response(self, command_name, catch_errors): conn = self.connection response = conn.read().strip() if not response: self.connection.disconnect() raise ConnectionError("Socket closed on remote end") # server returned a null value if response in ('$-1', '*-1'): return None byte, response = response[0], response[1:] # server returned an error if byte == '-': if response.startswith('ERR '): response = response[4:] raise ResponseError(response) # single value elif byte == '+': return response # int value elif byte == ':': return int(response) # bulk response elif byte == '$': length = int(response) if length == -1: return None response = length and conn.read(length) or '' conn.read(2) # read the \r\n delimiter return response # multi-bulk response elif byte == '*': length = int(response) if length == -1: return None if not catch_errors: return [ self._parse_response(command_name, catch_errors) for i in range(length) ] else: # for pipelines, we need to read everything, including response errors. # otherwise we'd completely mess up the receive buffer data = [] for i in range(length): try: data.append( self._parse_response(command_name, catch_errors)) except Exception, e: data.append(e) return data
def _execute_transaction(self, connection, commands):
    """Run *commands* inside MULTI/EXEC on *connection* and return EXEC's reply.

    Every reply (MULTI, each queued command, EXEC) is read even after an
    error, so the socket stays in sync. Queueing errors are annotated and
    spliced into the result list at their command's position. Commands
    whose options contain ``EMPTY_RESPONSE`` are not sent at all; their
    ``EMPTY_RESPONSE`` value is spliced in instead. No response callbacks
    are applied here; the raw reply list is returned.

    :raises ExecAbortError: or the first queueing error when EXEC aborts.
    :raises WatchError: when EXEC's reply is None.
    :raises ResponseError: when the result count does not match *commands*.
    """
    cmds = chain([(('MULTI', ), {})], commands, [(('EXEC', ), {})])
    all_cmds = connection.pack_commands(
        [args for args, options in cmds if EMPTY_RESPONSE not in options])
    connection.send_packed_command(all_cmds)
    errors = []

    # parse off the response for MULTI
    # NOTE: we need to handle ResponseErrors here and continue
    # so that we read all the additional command messages from
    # the socket
    try:
        self.parse_response(connection, '_')
    except ResponseError:
        errors.append((0, sys.exc_info()[1]))

    # and all the other commands
    for i, command in enumerate(commands):
        if EMPTY_RESPONSE in command[1]:
            # This command was never sent, so there is no reply to read;
            # splice in its preset value at this position instead.
            errors.append((i, command[1][EMPTY_RESPONSE]))
        else:
            try:
                self.parse_response(connection, '_')
            except ResponseError:
                ex = sys.exc_info()[1]
                self.annotate_exception(ex, i + 1, command[0])
                errors.append((i, ex))

    # parse the EXEC.
    try:
        response = self.parse_response(connection, '_')
    except ExecAbortError:
        # Capture before issuing DISCARD so the abort error cannot be
        # replaced by exception state from that call.
        exec_abort = sys.exc_info()[1]
        if self.explicit_transaction:
            self.immediate_execute_command('DISCARD')
        if errors:
            raise errors[0][1]
        raise exec_abort

    if response is None:
        raise WatchError("Watched variable changed.")

    # put any parse errors into the response
    for i, e in errors:
        response.insert(i, e)

    if len(response) != len(commands):
        # The stream on *this* connection is now out of sync; drop it.
        # (self.connection may be None or a different connection.)
        connection.disconnect()
        raise ResponseError(
            "Wrong number of response items from pipeline execution")

    return response
def test_error_handler_unknown_message(self):
    """An error message the handler cannot interpret is re-raised as-is."""
    script = parse_script(name='foo', content="")
    exception = ResponseError("ERR Unknown error")

    with self.assertRaises(ResponseError) as error, error_handler(script=script):
        raise exception

    self.assertIs(exception, error.exception)
def test_error_handler(self):
    """A Lua error with a script location is rewrapped as a ScriptError.

    The error message references line 11 of ``f_1234abc``. The resulting
    ScriptError must carry the script, that line number, the Lua error text
    and the leading message.
    """
    content = """
    local a = 1;
    local b = 2;
    local c = 3;
    local d = 4;
    local e = 5;
    local f = 6;
    local g = 7;
    local h = 8;
    local i = 9;
    local j = 10;
    local k = 11;
    local l = 12;
    """
    script = parse_script(name='foo', content=content)
    exception = ResponseError(
        "ERR something is wrong: f_1234abc:11: my lua error",
    )

    with self.assertRaises(ScriptError) as error, error_handler(script=script):
        raise exception

    raised = error.exception
    self.assertEqual(script, raised.script)
    self.assertEqual(11, raised.line)
    self.assertEqual('my lua error', raised.lua_error)
    self.assertEqual('something is wrong', raised.message)
def read_response(self):
    """Read one reply from the socket and decode it.

    Error replies are *returned* as ResponseError instances (with a leading
    ``ERR `` removed). Raising them is left to a higher level. A ``LOADING``
    error raises ConnectionError instead. Bytes results are decoded with
    ``self.encoding`` when one is set.
    """
    raw = self.read()
    if not raw:
        raise ConnectionError("Socket closed on remote end")

    prefix = byte_to_chr(raw[0])
    payload = raw[1:]
    if prefix not in ('-', '+', ':', '$', '*'):
        raise InvalidResponse("Protocol Error")

    if prefix == '-':
        if nativestr(payload).startswith('LOADING '):
            # if we're loading the dataset into memory, kill the socket
            # so we re-initialize (and re-SELECT) next time.
            raise ConnectionError("Redis is loading data into memory")
        if nativestr(payload).startswith('ERR '):
            payload = payload[4:]
        return ResponseError(payload)

    # '+' (status) replies pass through unchanged.
    if prefix == ':':
        payload = long(payload)
    elif prefix in ('$', '*'):
        length = int(payload)
        if length == -1:
            return None
        if prefix == '$':
            payload = self.read(length)
        else:
            payload = [self.read_response() for _ in xrange(length)]

    if isinstance(payload, bytes) and self.encoding:
        payload = payload.decode(self.encoding)
    return payload
def _parse_response(self, command_name): conn = self.connection response = conn.read().strip() if not response: self.connection.disconnect() raise ConnectionError("Socket closed on remote end") # server returned a null value if response in ('$-1', '*-1'): return None byte, response = response[0], response[1:] # server returned an error if byte == '-': if response.startswith('ERR '): response = response[4:] raise ResponseError(response) # single value elif byte == '+': return response # int value elif byte == ':': return int(response) # bulk response elif byte == '$': length = int(response) if length == -1: return None response = length and conn.read(length) or '' conn.read(2) # read the \r\n delimiter return response # multi-bulk response elif byte == '*': length = int(response) if length == -1: return None return [self._parse_response(command_name) for i in range(length)] raise InvalidResponse("Unknown response type for: %s" % command_name)
def _execute(self, commands):
    """Send a MULTI ... EXEC batch in one write and return the per-command results.

    *commands* is a sequence of 3-tuples ``(command_name, packed_command,
    options)`` whose first and last entries are the MULTI and EXEC commands.
    Error replies stay in the result list as exception objects. Every other
    reply goes through ``self.RESPONSE_CALLBACKS[command_name]`` when one is
    registered.

    :raises ResponseError: when the number of results does not match the
        number of queued commands.
    """
    # build up all commands into a single request to increase network perf
    all_cmds = ''.join([packed for _name, packed, _opts in commands])
    self.connection.send(all_cmds, self)
    # Discard the replies to MULTI and to every queued command; only the
    # last reply (EXEC's) carries the results.
    for _ in range(len(commands) - 1):
        self.parse_response('_')
    # tell the response parser to catch errors and return them as
    # part of the response
    response = self.parse_response('_', catch_errors=True)
    # don't return the results of the MULTI or EXEC command
    commands = [(c[0], c[2]) for c in commands[1:-1]]
    # NOTE(review): if parse_response can return None for a null EXEC reply,
    # len() below raises TypeError — confirm.
    if len(response) != len(commands):
        raise ResponseError("Wrong number of response items from "
                            "pipeline execution")
    # Run any callbacks for the commands run in the pipeline
    data = []
    for r, (command_name, options) in zip(response, commands):
        if not isinstance(r, Exception) and command_name in self.RESPONSE_CALLBACKS:
            r = self.RESPONSE_CALLBACKS[command_name](r, **options)
        data.append(r)
    return data
def execute_command_via_connection(r, *argv, **kwargs):
    """Test wrapper: fake an ASK error once, then delegate and record everything.

    The first call (when ``test.execute_command_calls`` is empty) raises a
    ResponseError with an ASK redirect to 127.0.0.1:7003. Later calls forward
    to ``execute_command_via_connection_original`` and record the arguments
    together with the result or the exception raised.
    """
    if not test.execute_command_calls:
        e = ResponseError("ASK 1 127.0.0.1:7003")
        test.execute_command_calls.append({'exception': e})
        raise e

    try:
        result = execute_command_via_connection_original(r, *argv, **kwargs)
    except Exception as e:
        test.execute_command_calls.append({
            'argv': argv,
            'kwargs': kwargs,
            'exception': e
        })
        # Bare raise keeps the original traceback (`raise e` rewrites it on Python 2).
        raise

    test.execute_command_calls.append({
        'argv': argv,
        'kwargs': kwargs,
        'result': result
    })
    return result
def read_response(self):
    """Read one reply from the socket and decode it.

    Error replies are *returned* as ResponseError instances (leading
    ``ERR `` removed), whatever their prefix. A ``LOADING`` error raises
    ConnectionError instead. Returns None for null bulk / multi-bulk
    replies. An unknown reply prefix raises InvalidResponse.
    """
    response = self.read()
    if not response:
        raise ConnectionError("Socket closed on remote end")

    byte, response = response[0], response[1:]

    # server returned an error
    if byte == '-':
        if response.startswith('LOADING '):
            # The server is still loading its dataset; surface this as a
            # connection problem so the connection is re-initialized (and
            # re-SELECTed) next time.
            raise ConnectionError("Redis is loading data into memory")
        if response.startswith('ERR '):
            response = response[4:]
        # Every other error prefix (e.g. WRONGTYPE, NOSCRIPT) is still a
        # server error. Previously these fell through to "Protocol Error".
        return ResponseError(response)
    # single value
    elif byte == '+':
        return response
    # int value
    elif byte == ':':
        return long(response)
    # bulk response
    elif byte == '$':
        length = int(response)
        if length == -1:
            return None
        response = self.read(length)
        return response
    # multi-bulk response
    elif byte == '*':
        length = int(response)
        if length == -1:
            return None
        return [self.read_response() for i in xrange(length)]
    raise InvalidResponse("Protocol Error")
def bgsave(self):
    """Record a BGSAVE call in ``self._called``.

    Raises a bare ResponseError afterwards when
    ``self.bgsave_raises_ResponseError`` is truthy; otherwise returns None.
    """
    self._called.append('BGSAVE')
    if not self.bgsave_raises_ResponseError:
        return None
    raise ResponseError()
def psetex(self, name, time_ms, value):
    """Set *name* to *value* with a millisecond expiry.

    *time_ms* may be a ``timedelta``, which is converted to whole
    milliseconds. A zero expiry raises ResponseError. Otherwise this
    delegates to ``self.set(..., px=...)`` and returns its result.
    """
    millis = time_ms
    if isinstance(millis, timedelta):
        millis = int(timedelta_total_seconds(millis) * 1000)
    if millis == 0:
        raise ResponseError("invalid expire time in SETEX")
    return self.set(name, value, px=millis)
def expireat(self, key, timestamp):
    """Store *timestamp* as the expiry of *key* in ``self.expiry``.

    Only ``int`` timestamps are accepted; anything else raises
    ResponseError. Returns None.
    """
    if isinstance(timestamp, int):
        self.expiry[key] = timestamp
        return None
    raise ResponseError("value is not an integer or out of range")
def mock_register(cls, redis):
    """Replacement registration hook that always fails.

    Ignores both arguments and raises a bare ResponseError.
    """
    failure = ResponseError()
    raise failure