async def call_rpc(self, rpc_message: RpcMessage, options: dict, bus_client: "BusClient"):
    """Enqueue an RPC message in Redis, retrying on transient connection errors.

    The actual write is delegated to ``self._call_rpc()`` using the list key
    ``<api_name>:rpc_queue`` and the expiry key ``rpc_expiry_key:<message id>``.
    Up to three attempts are made, sleeping ``self.rpc_retry_delay`` seconds
    between failed attempts.

    Args:
        rpc_message: The RPC message to enqueue.
        options: Call options (not used directly by this method).
        bus_client: The calling bus client (not used directly by this method).

    Raises:
        PipelineError, ConnectionClosedError, ConnectionResetError: re-raised
            when the final attempt still fails.
    """
    queue_key = f"{rpc_message.api_name}:rpc_queue"
    expiry_key = f"rpc_expiry_key:{rpc_message.id}"

    logger.debug(
        LBullets(
            L("Enqueuing message {} in Redis list {}", Bold(rpc_message), Bold(queue_key)),
            items=dict(**rpc_message.get_metadata(), kwargs=rpc_message.get_kwargs()),
        )
    )

    # The elapsed time logged below includes any retry delays.
    start_time = time.time()
    for try_number in range(3, 0, -1):
        last_try = try_number == 1
        try:
            await self._call_rpc(rpc_message, queue_key, expiry_key)
        except (PipelineError, ConnectionClosedError, ConnectionResetError):
            if last_try:
                raise
            await asyncio.sleep(self.rpc_retry_delay)
        else:
            # Break (rather than return) so the "Enqueued" log below is reached.
            break

    logger.debug(
        L(
            "Enqueued message {} in Redis in {} list {}",
            Bold(rpc_message),
            human_time(time.time() - start_time),
            Bold(queue_key),
        )
    )
async def call_rpc(self, rpc_message: RpcMessage, options: dict):
    """Enqueue an RPC message onto the API's Redis list and set its expiry key.

    The serialized message is RPUSHed onto ``<api_name>:rpc_queue``. The key
    ``rpc_expiry_key:<message id>`` is set to 1 with a TTL of
    ``self.rpc_timeout``. All three commands go out in a single pipeline.

    Args:
        rpc_message: The RPC message to enqueue.
        options: Call options (not used by this implementation).
    """
    queue_key = f"{rpc_message.api_name}:rpc_queue"
    expiry_key = f"rpc_expiry_key:{rpc_message.id}"
    logger.debug(
        LBullets(
            L("Enqueuing message {} in Redis list {}", Bold(rpc_message), Bold(queue_key)),
            items=dict(**rpc_message.get_metadata(), kwargs=rpc_message.get_kwargs()),
        )
    )

    with await self.connection_manager() as redis:
        start_time = time.time()
        pipe = redis.pipeline()
        pipe.rpush(key=queue_key, value=self.serializer(rpc_message))
        # Flag key that Redis deletes automatically once rpc_timeout elapses.
        pipe.set(expiry_key, 1)
        pipe.expire(expiry_key, timeout=self.rpc_timeout)
        await pipe.execute()

    logger.debug(
        L(
            "Enqueued message {} in Redis in {} list {}",
            Bold(rpc_message),
            human_time(time.time() - start_time),
            Bold(queue_key),
        )
    )
async def call_rpc_remote(self, api_name: str, name: str, kwargs: dict, options: dict):
    """Call a remote RPC and wait for its result.

    The call is sent via ``self.rpc_transport`` while the response is awaited
    concurrently via ``self.receive_result()``. The return path comes from
    ``self.result_transport``.

    Args:
        api_name: Name of the API providing the procedure.
        name: Name of the procedure to call.
        kwargs: Keyword arguments for the procedure.
        options: Call options. ``timeout`` (seconds, default 5) is honoured
            here. ``None`` is treated as an empty dict.

    Returns:
        The ``result`` of the received result message.

    Raises:
        LightbusTimeout: if no result arrives within the timeout.
        LightbusServerError: if the result message reports an error.
    """
    rpc_message = RpcMessage(api_name=api_name, procedure_name=name, kwargs=kwargs)
    return_path = self.result_transport.get_return_path(rpc_message)
    rpc_message.return_path = return_path
    options = options or {}
    timeout = options.get('timeout', 5)

    # Include the RPC's name in the log line (the format string needs a placeholder).
    logger.info(L("➡ Calling remote RPC {}", Bold(rpc_message.canonical_name)))

    start_time = time.time()

    # Run the hook *before* asyncio.gather(). gather() wraps the transport call
    # in a task that starts as soon as we yield to the event loop. Awaiting the
    # hook afterwards could therefore send the RPC before the "before" hook
    # ran, and a failing hook would leave the call running un-awaited.
    await plugin_hook('before_rpc_call', rpc_message=rpc_message, bus_client=self)

    # TODO: It is possible that the RPC will be called before we start waiting for the response. This is bad.
    future = asyncio.gather(
        self.receive_result(rpc_message, return_path, options=options),
        self.rpc_transport.call_rpc(rpc_message, options=options),
    )

    try:
        result_message, _ = await asyncio.wait_for(future, timeout=timeout)
    except asyncio.TimeoutError:
        future.cancel()
        # TODO: Include description of possible causes and how to increase the timeout.
        # TODO: Remove RPC from queue. Perhaps add a RpcBackend.cancel() method. Optional,
        #       as not all backends will support it. No point processing calls which have timed out.
        raise LightbusTimeout('Timeout when calling RPC {} after {} seconds'.format(
            rpc_message.canonical_name, timeout
        )) from None

    await plugin_hook('after_rpc_call', rpc_message=rpc_message, result_message=result_message, bus_client=self)

    if not result_message.error:
        logger.info(L("⚡ Remote call of {} completed in {}",
                      Bold(rpc_message.canonical_name), human_time(time.time() - start_time)))
    else:
        logger.warning(
            L("⚡ Server error during remote call of {}. Took {}: {}",
              Bold(rpc_message.canonical_name),
              human_time(time.time() - start_time),
              result_message.result,
              ),
        )
        raise LightbusServerError('Error while calling {}: {}\nRemote stack trace:\n{}'.format(
            rpc_message.canonical_name,
            result_message.result,
            result_message.trace,
        ))

    return result_message.result
async def test_receive_result(redis_result_transport: RedisResultTransport, redis_client):
    """receive_result() should decode a result pushed onto the return-path list."""
    # Await the push. If the client hands back a coroutine/future, leaving it
    # un-awaited means the LPUSH may never run (or its errors are never seen).
    await redis_client.lpush(
        key='my.api.my_proc:result:e1821498-e57c-11e7-af9d-7831c1c3936e',
        value=json.dumps({
            'result': 'All done! 😎',
            'rpc_id': '123abc',
            'error': False,
        }),
    )
    result_message = await redis_result_transport.receive_result(
        rpc_message=RpcMessage(
            rpc_id='123abc',
            api_name='my.api',
            procedure_name='my_proc',
            kwargs={'field': 'value'},
            return_path='abc',
        ),
        return_path='redis+key://my.api.my_proc:result:e1821498-e57c-11e7-af9d-7831c1c3936e',
        options={},
    )
    assert result_message.result == 'All done! 😎'
    assert result_message.rpc_id == '123abc'
    assert result_message.error is False
async def test_send_result(redis_result_transport: RedisResultTransport, redis_client):
    """send_result() should store the JSON-encoded result under the return-path key."""
    result_key = 'my.api.my_proc:result:e1821498-e57c-11e7-af9d-7831c1c3936e'

    request = RpcMessage(
        rpc_id='123abc',
        api_name='my.api',
        procedure_name='my_proc',
        kwargs={'field': 'value'},
        return_path='abc',
    )
    response = ResultMessage(rpc_id='123abc', result='All done! 😎')
    await redis_result_transport.send_result(
        rpc_message=request,
        result_message=response,
        return_path=f'redis+key://{result_key}',
    )

    # Exactly one key should exist afterwards: the one the result was pushed to
    assert await redis_client.keys('*') == [result_key.encode('utf8')]

    raw_payload = await redis_client.lpop(result_key)
    expected = {'error': False, 'rpc_id': '123abc', 'result': 'All done! 😎'}
    assert json.loads(raw_payload) == expected
async def test_send_result(redis_result_transport: RedisResultTransport, redis_client):
    """send_result() should store the result as a metadata/kwargs JSON document."""
    result_key = "my.api.my_proc:result:e1821498-e57c-11e7-af9d-7831c1c3936e"

    request = RpcMessage(
        id="123abc",
        api_name="my.api",
        procedure_name="my_proc",
        kwargs={"field": "value"},
        return_path="abc",
    )
    response = ResultMessage(id="345", rpc_message_id="123abc", result="All done! 😎")
    await redis_result_transport.send_result(
        rpc_message=request,
        result_message=response,
        return_path=f"redis+key://{result_key}",
        bus_client=None,
    )

    # Only the result key should have been written
    assert await redis_client.keys("*") == [result_key.encode("utf8")]

    raw_payload = await redis_client.lpop(result_key)
    expected = {
        "metadata": {"error": False, "rpc_message_id": "123abc", "id": "345"},
        "kwargs": {"result": "All done! 😎"},
    }
    assert json.loads(raw_payload) == expected
async def test_call_rpc(redis_rpc_transport, redis_client):
    """call_rpc() should push one message onto the API's RPC list and set a TTL'd expiry key"""
    queue_key = "my.api:rpc_queue"
    expiry_key = "rpc_expiry_key:123abc"

    outgoing = RpcMessage(
        id="123abc",
        api_name="my.api",
        procedure_name="my_proc",
        kwargs={"field": "value"},
        return_path="abc",
    )
    await redis_rpc_transport.call_rpc(outgoing, options={})

    present_keys = set(await redis_client.keys("*"))
    assert present_keys == {queue_key.encode("utf8"), expiry_key.encode("utf8")}

    queued = await redis_client.lrange(queue_key, start=0, stop=100)
    assert len(queued) == 1

    expected_message = {
        "metadata": {
            "id": "123abc",
            "api_name": "my.api",
            "procedure_name": "my_proc",
            "return_path": "abc",
        },
        "kwargs": {"field": "value"},
    }
    assert json.loads(queued[0]) == expected_message

    # The expiry key must exist and carry the transport's configured timeout
    assert await redis_client.exists(expiry_key)
    assert await redis_client.ttl(expiry_key) == redis_rpc_transport.rpc_timeout
async def test_receive_result(redis_result_transport: RedisResultTransport, redis_client):
    """receive_result() should decode a metadata/kwargs result from the return-path list."""
    # Await the push. If the client hands back a coroutine/future, leaving it
    # un-awaited means the LPUSH may never run (or its errors are never seen).
    await redis_client.lpush(
        key="my.api.my_proc:result:e1821498-e57c-11e7-af9d-7831c1c3936e",
        value=json.dumps({
            "metadata": {
                "rpc_message_id": "123abc",
                "error": False,
                "id": "123"
            },
            "kwargs": {
                "result": "All done! 😎"
            },
        }),
    )
    result_message = await redis_result_transport.receive_result(
        rpc_message=RpcMessage(
            id="123abc",
            api_name="my.api",
            procedure_name="my_proc",
            kwargs={"field": "value"},
            return_path="abc",
        ),
        return_path="redis+key://my.api.my_proc:result:e1821498-e57c-11e7-af9d-7831c1c3936e",
        options={},
        bus_client=None,
    )
    assert result_message.result == "All done! 😎"
    assert result_message.rpc_message_id == "123abc"
    assert result_message.id == "123"
    assert result_message.error is False
async def consume_rpcs(self, apis: Sequence[Api]) -> Sequence[RpcMessage]:
    """Read pending RPC messages from the Redis stream of each given API.

    Each API maps to the stream ``<api name>:stream``. Reading resumes from the
    last message id recorded in ``self._latest_ids``. A stream with no recorded
    id starts from ``'$'`` (only messages added after the read begins). Every
    message received advances the stored id for its stream.

    Returns:
        The decoded ``RpcMessage`` objects, in the order Redis returned them.
    """
    stream_names = [f'{api.meta.name}:stream' for api in apis]
    start_ids = [self._latest_ids.get(name, '$') for name in stream_names]

    logger.debug(LBullets(
        'Consuming RPCs from',
        items=[f'{name} ({start_id})' for name, start_id in zip(stream_names, start_ids)],
    ))

    pool = await self.get_redis_pool()
    with await pool as redis:
        # TODO: Count/timeout configurable
        raw_entries = await redis.xread(stream_names, latest_ids=start_ids, count=10)

    received = []
    for raw_stream, raw_id, raw_fields in raw_entries:
        stream_name = decode(raw_stream, 'utf8')
        message_id = decode(raw_id, 'utf8')
        fields = decode_message_fields(raw_fields)

        # See comment on events transport re updating message_id
        self._latest_ids[stream_name] = message_id
        received.append(RpcMessage.from_dict(fields))

        logger.debug(LBullets(
            L("⬅ Received message {} on stream {}", Bold(message_id), Bold(stream_name)),
            items=fields,
        ))

    return received
def _get_fake_messages(self):
    """Return the canned RPC messages this fake transport produces.

    Currently a single ``my_company.auth.check_password`` call with masked
    credentials.
    """
    masked_credentials = {"username": "******", "password": "******"}
    fake_call = RpcMessage(
        api_name="my_company.auth",
        procedure_name="check_password",
        kwargs=masked_credentials,
    )
    return [fake_call]
async def dummy_transport_consume_rpcs(*args, **kwargs):
    """Fake consume_rpcs(): one message on the first call, StopIt afterwards.

    ``m`` is the mock wrapping this function (defined by the caller). Its
    ``call_count`` tells us which call we are on.
    """
    if m.call_count != 1:
        # Presumably used to break out of the consumer loop under test
        raise StopIt()
    only_message = RpcMessage(api_name="my.dummy", procedure_name="my_proc", kwargs={"field": 123})
    return [only_message]
async def test_local_rpc_call(loop, dummy_bus: BusPath, consume_rpcs, get_dummy_events, mocker):
    """Consuming one RPC should emit 'received' and 'sent' metrics events."""
    transport = dummy_bus.client.transport_registry.get_rpc_transport("default")
    fake_message = RpcMessage(
        id="123abc", api_name="example.test", procedure_name="my_method", kwargs={"f": 123}
    )
    mocker.patch.object(transport, "_get_fake_messages", return_value=[fake_message])

    # Configure the metrics plugin and register the API being called
    metrics_plugin = MetricsPlugin(service_name="foo", process_name="bar")
    manually_set_plugins(plugins={"metrics": metrics_plugin})
    registry.add(TestApi())

    consumer = asyncio.ensure_future(consume_rpcs(dummy_bus), loop=loop)
    # The dummy transport emits a message every 0.1 seconds, so 0.15s
    # should be enough for one RPC to be processed
    await asyncio.sleep(0.15)
    await cancel(consumer)

    events = get_dummy_events()
    assert len(events) == 2, events
    call_received, response_sent = events

    expected_kwargs = {
        "api_name": "example.test",
        "procedure_name": "my_method",
        "id": "123abc",
        "service_name": "foo",
        "process_name": "bar",
    }

    # before_rpc_execution
    assert call_received.api_name == "internal.metrics"
    assert call_received.event_name == "rpc_call_received"
    assert call_received.kwargs.pop("timestamp")
    assert call_received.kwargs == expected_kwargs

    # after_rpc_execution
    assert response_sent.api_name == "internal.metrics"
    assert response_sent.event_name == "rpc_response_sent"
    assert response_sent.kwargs.pop("timestamp")
    assert response_sent.kwargs == {**expected_kwargs, "result": "value"}
async def call_rpc(self, rpc_message: RpcMessage, options: dict):
    """Publish an RPC message to the API's Redis stream.

    The message dict is encoded with ``encode_message_fields()`` and appended
    with XADD to ``<api_name>:stream``. ``options`` is not used here.
    """
    stream_name = f'{rpc_message.api_name}:stream'
    logger.debug(
        LBullets(
            L("Enqueuing message {} in Redis stream {}", Bold(rpc_message), Bold(stream_name)),
            items=rpc_message.to_dict(),
        )
    )

    redis_pool = await self.get_redis_pool()
    with await redis_pool as connection:
        start_time = time.time()
        # TODO: MAXLEN
        encoded_fields = encode_message_fields(rpc_message.to_dict())
        await connection.xadd(stream=stream_name, fields=encoded_fields)

    logger.info(L(
        "Enqueued message {} in Redis in {} stream {}",
        Bold(rpc_message),
        human_time(time.time() - start_time),
        Bold(stream_name),
    ))
async def test_get_return_path(redis_result_transport: RedisResultTransport):
    """get_return_path() should yield a redis+key:// path ending in a base64 UUID."""
    message = RpcMessage(
        api_name='my.api',
        procedure_name='my_proc',
        kwargs={'field': 'value'},
        return_path='abc',
    )
    path = redis_result_transport.get_return_path(message)

    assert path.startswith('redis+key://my.api.my_proc:result:')

    *_, encoded_uuid = path.split(':')
    # UUID(bytes=...) raises ValueError unless exactly 16 bytes were decoded.
    # That construction is the real check; the resulting object is always truthy.
    assert UUID(bytes=b64decode(encoded_uuid))
async def dummy_transport_consume_rpcs(*args, **kwargs):
    """Fake consume_rpcs(): one message on the first call, StopIt afterwards.

    ``m`` is the mock wrapping this function (defined by the caller). Its
    ``call_count`` tells us which call we are on.
    """
    if m.call_count != 1:
        # Presumably used to break out of the consumer loop under test
        raise StopIt()
    only_message = RpcMessage(api_name='my.dummy', procedure_name='my_proc', kwargs={'field': 123})
    return [only_message]
async def test_get_return_path(redis_result_transport: RedisResultTransport):
    """get_return_path() should yield a redis+key:// path ending in a base64 UUID."""
    message = RpcMessage(
        api_name="my.api",
        procedure_name="my_proc",
        kwargs={"field": "value"},
        return_path="abc",
    )
    path = redis_result_transport.get_return_path(message)

    assert path.startswith("redis+key://my.api.my_proc:result:")

    *_, encoded_uuid = path.split(":")
    # UUID(bytes=...) raises ValueError unless exactly 16 bytes were decoded.
    # That construction is the real check; the resulting object is always truthy.
    assert UUID(bytes=b64decode(encoded_uuid))
async def test_reconnect_upon_call_rpc(redis_rpc_transport, redis_client):
    """call_rpc() should still enqueue its message after its Redis connection was killed"""
    # Drop all normal client connections, including the RPC transport's
    await redis_client.execute(b"CLIENT", b"KILL", b"TYPE", b"NORMAL")

    # The transport must transparently reconnect and enqueue without raising
    outgoing = RpcMessage(
        id="123abc",
        api_name="my.api",
        procedure_name="my_proc",
        kwargs={"field": "value"},
        return_path="abc",
    )
    await redis_rpc_transport.call_rpc(outgoing, options={})

    present_keys = set(await redis_client.keys("*"))
    assert present_keys == {b"my.api:rpc_queue", b"rpc_expiry_key:123abc"}

    queued = await redis_client.lrange("my.api:rpc_queue", start=0, stop=100)
    assert len(queued) == 1
async def test_local_rpc_call(dummy_bus: BusNode, rpc_consumer, get_dummy_events, mocker):
    """Processing one local RPC should emit 'received' and 'sent' metrics events."""
    transport = dummy_bus.bus_client.rpc_transport
    fake_message = RpcMessage(
        rpc_id='123abc', api_name='example.test', procedure_name='my_method', kwargs={'f': 123}
    )
    mocker.patch.object(transport, '_get_fake_messages', return_value=[fake_message])

    # Setup the bus and do the call
    manually_set_plugins(plugins={'metrics': MetricsPlugin()})
    registry.add(TestApi())

    # The dummy transport emits a message every 0.1 seconds, so 0.15s
    # should be enough for one RPC to be processed
    await asyncio.sleep(0.15)

    events = get_dummy_events()
    assert len(events) == 2, events
    call_received, response_sent = events

    expected_kwargs = {
        'api_name': 'example.test',
        'procedure_name': 'my_method',
        'rpc_id': '123abc',
    }

    # before_rpc_execution
    assert call_received.api_name == 'internal.metrics'
    assert call_received.event_name == 'rpc_call_received'
    assert call_received.kwargs.pop('timestamp')
    assert call_received.kwargs.pop('process_name')
    assert call_received.kwargs == expected_kwargs

    # after_rpc_execution
    assert response_sent.api_name == 'internal.metrics'
    assert response_sent.event_name == 'rpc_response_sent'
    assert response_sent.kwargs.pop('timestamp')
    assert response_sent.kwargs.pop('process_name')
    assert response_sent.kwargs == {**expected_kwargs, 'result': 'value'}
async def test_call_rpc(redis_rpc_transport, redis_client):
    """call_rpc() should XADD exactly one entry to the API's stream"""
    stream_key = 'my.api:stream'

    outgoing = RpcMessage(
        rpc_id='123abc',
        api_name='my.api',
        procedure_name='my_proc',
        kwargs={'field': 'value'},
        return_path='abc',
    )
    await redis_rpc_transport.call_rpc(outgoing, options={})

    # The stream should be the only key written
    assert await redis_client.keys('*') == [stream_key.encode('utf8')]

    entries = await redis_client.xrange(stream_key)
    assert len(entries) == 1

    # Field values are JSON-encoded; procedure kwargs are stored under a "kw:" prefix
    expected_fields = {
        b'rpc_id': b'"123abc"',
        b'api_name': b'"my.api"',
        b'procedure_name': b'"my_proc"',
        b'kw:field': b'"value"',
        b'return_path': b'"abc"',
    }
    assert entries[0][1] == expected_fields
async def call_rpc_remote(self, api_name: str, name: str, kwargs: dict = frozendict(), options: dict = frozendict()):
    """Call a remote RPC and return its result.

    The call is sent with the API's RPC transport while the response is awaited
    concurrently on a return path obtained from the API's result transport.

    Args:
        api_name: Name of the API providing the procedure.
        name: Name of the procedure to call.
        kwargs: Keyword arguments for the procedure (passed through ``deform_to_bus``).
        options: Call options. ``timeout`` overrides the API's configured
            ``rpc_timeout``. The options are also forwarded to both transports.

    Returns:
        The ``result`` of the received result message.

    Raises:
        LightbusTimeout: if no result is received within the timeout.
        LightbusServerError: if the result message reports an error.
    """
    rpc_transport = self.transport_registry.get_rpc_transport(api_name)
    result_transport = self.transport_registry.get_result_transport(api_name)

    kwargs = deform_to_bus(kwargs)
    rpc_message = RpcMessage(api_name=api_name, procedure_name=name, kwargs=kwargs)
    return_path = result_transport.get_return_path(rpc_message)
    rpc_message.return_path = return_path
    options = options or {}
    timeout = options.get("timeout", self.config.api(api_name).rpc_timeout)

    self._validate_name(api_name, "rpc", name)

    logger.info("📞 Calling remote RPC {}.{}".format(Bold(api_name), Bold(name)))

    start_time = time.time()

    self._validate(rpc_message, "outgoing")

    # Run the hook *before* asyncio.gather(). gather() wraps the transport call
    # in a task that starts as soon as we yield to the event loop. Awaiting the
    # hook afterwards could therefore send the RPC before the "before" hook
    # ran, and a failing hook would leave the call running un-awaited.
    await self._plugin_hook("before_rpc_call", rpc_message=rpc_message)

    # TODO: It is possible that the RPC will be called before we start waiting for the response. This is bad.
    future = asyncio.gather(
        self.receive_result(rpc_message, return_path, options=options),
        rpc_transport.call_rpc(rpc_message, options=options),
    )

    try:
        result_message, _ = await asyncio.wait_for(future, timeout=timeout)
    except asyncio.TimeoutError:
        # Allow the future to finish, as per https://bugs.python.org/issue29432
        try:
            await future
        except CancelledError:
            pass
        # TODO: Remove RPC from queue. Perhaps add a RpcBackend.cancel() method. Optional,
        #       as not all backends will support it. No point processing calls which have timed out.
        raise LightbusTimeout(
            f"Timeout when calling RPC {rpc_message.canonical_name} after {timeout} seconds. "
            f"It is possible no Lightbus process is serving this API, or perhaps it is taking "
            f"too long to process the request. In which case consider raising the 'rpc_timeout' "
            f"config option."
        ) from None

    await self._plugin_hook("after_rpc_call", rpc_message=rpc_message, result_message=result_message)

    if not result_message.error:
        logger.info(
            L(
                "🏁 Remote call of {} completed in {}",
                Bold(rpc_message.canonical_name),
                human_time(time.time() - start_time),
            )
        )
    else:
        logger.warning(
            L(
                "⚡ Server error during remote call of {}. Took {}: {}",
                Bold(rpc_message.canonical_name),
                human_time(time.time() - start_time),
                result_message.result,
            )
        )
        raise LightbusServerError(
            "Error while calling {}: {}\nRemote stack trace:\n{}".format(
                rpc_message.canonical_name, result_message.result, result_message.trace
            )
        )

    self._validate(result_message, "incoming", api_name, procedure_name=name)

    return result_message.result
async def call_rpc_remote(
    self, api_name: str, name: str, kwargs: dict = frozendict(), options: dict = frozendict()
):
    """Perform an RPC call and return its result.

    The call is handed to ``self.producer`` as a ``CallRpcCommand``. A
    ``ReceiveResultCommand`` is then sent so that a listener delivers the
    response into a local ``InternalQueue``. Timeouts are reported by that
    listener (as an ``asyncio.TimeoutError`` placed on the queue) rather than
    enforced here.

    Args:
        api_name: Name of the API providing the procedure.
        name: Name of the procedure to call.
        kwargs: Keyword arguments for the procedure (passed through ``deform_to_bus``).
        options: Call options, forwarded to both commands.

    Returns:
        The ``result`` of the received ``ResultMessage``.

    Raises:
        LightbusTimeout: if the listener reports an ``asyncio.TimeoutError``.
        LightbusWorkerError: if the result message reports an error.
        Any other exception delivered on the result queue is re-raised as-is.
    """
    kwargs = deform_to_bus(kwargs)
    rpc_message = RpcMessage(api_name=api_name, procedure_name=name, kwargs=kwargs)

    validate_event_or_rpc_name(api_name, "rpc", name)

    logger.info("📞 Calling remote RPC {}.{}".format(Bold(api_name), Bold(name)))

    start_time = time.time()

    validate_outgoing(self.config, self.schema, rpc_message)

    await self.hook_registry.execute("before_rpc_call", rpc_message=rpc_message)

    result_queue = InternalQueue()

    # Send the RPC
    # NOTE(review): the call is sent before the result listener is started.
    # Presumably safe only if results persist until read by the listener — confirm.
    await self.producer.send(
        commands.CallRpcCommand(message=rpc_message, options=options)
    ).wait()

    # Start a listener which will wait for results
    await self.producer.send(
        commands.ReceiveResultCommand(
            message=rpc_message, destination_queue=result_queue, options=options
        )
    ).wait()

    # Wait for the result from the listener we started.
    # The RpcResultDock will handle timeouts
    # bail_on_error() also watches self.error_queue, presumably so that a failure
    # elsewhere aborts this wait — confirm against its definition.
    result = await bail_on_error(self.error_queue, result_queue.get())

    call_time = time.time() - start_time

    # The queue carries either a ResultMessage or an exception. A timeout is
    # translated into LightbusTimeout; any other exception propagates unchanged.
    try:
        if isinstance(result, Exception):
            raise result
    except asyncio.TimeoutError:
        raise LightbusTimeout(
            f"Timeout when calling RPC {rpc_message.canonical_name} after waiting for {human_time(call_time)}. "
            f"It is possible no Lightbus process is serving this API, or perhaps it is taking "
            f"too long to process the request. In which case consider raising the 'rpc_timeout' "
            f"config option."
        ) from None
    else:
        # NOTE(review): this assert is stripped under `python -O`.
        assert isinstance(result, ResultMessage)
        result_message = result

    await self.hook_registry.execute(
        "after_rpc_call", rpc_message=rpc_message, result_message=result_message
    )

    if not result_message.error:
        logger.info(
            L(
                "🏁 Remote call of {} completed in {}",
                Bold(rpc_message.canonical_name),
                human_time(call_time),
            )
        )
    else:
        logger.warning(
            L(
                "⚡ Error during remote call of RPC {}. Took {}: {}",
                Bold(rpc_message.canonical_name),
                human_time(call_time),
                result_message.result,
            )
        )
        raise LightbusWorkerError(
            "Error while calling {}: {}\nRemote stack trace:\n{}".format(
                rpc_message.canonical_name, result_message.result, result_message.trace
            )
        )

    # Only successful results reach schema validation.
    validate_incoming(self.config, self.schema, result_message)

    return result_message.result
def _get_fake_messages(self):
    """Return the canned RPC messages this fake transport produces.

    Currently a single ``my_company.auth.check_password`` call with masked
    credentials.
    """
    masked_credentials = {'username': '******', 'password': '******'}
    fake_call = RpcMessage(
        api_name='my_company.auth',
        procedure_name='check_password',
        kwargs=masked_credentials,
    )
    return [fake_call]