async def test_pubsub_receiver_iter(create_redis, server, loop):
    """``Receiver.iter()`` yields plain and pattern messages in arrival order.

    Two plain channels and one overlapping pattern feed a single receiver.
    After ``stop()`` is scheduled, the collecting task must finish with each
    message tagged by the channel object that delivered it, and the receiver
    must be inactive.
    """
    subscriber = await create_redis(server.tcp_address, loop=loop)
    publisher = await create_redis(server.tcp_address, loop=loop)
    receiver = Receiver(loop=loop)

    async def collect(rcv):
        return [item async for item in rcv.iter()]

    collector = asyncio.ensure_future(collect(receiver), loop=loop)

    (chan1,) = await subscriber.subscribe(receiver.channel('chan:1'))
    (chan2,) = await subscriber.subscribe(receiver.channel('chan:2'))
    (pattern,) = await subscriber.psubscribe(receiver.pattern('chan:*'))

    # Each publish reaches both the plain channel and the pattern.
    delivered_to = await publisher.publish_json('chan:1', {'Hello': 'World'})
    assert delivered_to > 1
    delivered_to = await publisher.publish_json('chan:2', ['message'])
    assert delivered_to > 1

    loop.call_later(0, receiver.stop)

    expected = [
        (chan1, b'{"Hello": "World"}'),
        (pattern, (b'chan:1', b'{"Hello": "World"}')),
        (chan2, b'["message"]'),
        (pattern, (b'chan:2', b'["message"]')),
    ]
    assert await collector == expected
    assert not receiver.is_active
async def test_decode_message_error(loop):
    """Decoding a raw ``bytes`` payload with ``json.loads`` raises ``TypeError``.

    Covers both a plain channel and a pattern channel. This only holds before
    Python 3.6. From 3.6 on, ``json.loads`` accepts ``bytes``/``bytearray``,
    the decode succeeds, and ``pytest.raises(TypeError)`` would report
    "DID NOT RAISE" (or an ``AssertionError`` for the pattern case). The test
    is therefore skipped on those versions instead of failing.
    """
    import sys

    if sys.version_info >= (3, 6):
        pytest.skip("json.loads accepts bytes since Python 3.6")

    mpsc = Receiver(loop)
    ch = mpsc.channel('channel:1')
    ch.put_nowait(b'{"hello": "world"}')
    unexpected = (mock.ANY, {'hello': 'world'})
    with pytest.raises(TypeError):
        assert (await mpsc.get(decoder=json.loads)) == unexpected

    ch = mpsc.pattern('*')
    ch.put_nowait((b'channel', b'{"hello": "world"}'))
    unexpected = (mock.ANY, b'channel', {'hello': 'world'})
    with pytest.raises(TypeError):
        assert (await mpsc.get(decoder=json.loads)) == unexpected
def test_listener_pattern(loop):
    """``Receiver.pattern`` returns one cached channel per pattern until closed."""
    receiver = Receiver(loop=loop)
    assert not receiver.is_active

    first = receiver.pattern("*")
    assert isinstance(first, AbcChannel)
    assert receiver.is_active

    # Asking for the same pattern again returns the very same channel object.
    second = receiver.pattern('*')
    assert first is second
    assert first.name == second.name
    assert first.is_pattern == second.is_pattern
    assert receiver.is_active

    # Closing the only channel deactivates both it and the receiver.
    first.close()
    assert not first.is_active
    assert not receiver.is_active

    # Requesting the pattern again creates a fresh channel, registered only
    # under ``patterns`` (keyed by the bytes pattern name).
    fresh = receiver.pattern("*")
    assert fresh is not first
    assert dict(receiver.channels) == {}
    assert dict(receiver.patterns) == {b'*': fresh}
async def run(self):
    """Forward reporter-agent stats published on Redis into ``DataSource``.

    Subscribes to every channel matching ``REPORTER_PREFIX*``. For each
    message, the payload is JSON-decoded and stored in
    ``DataSource.node_physical_stats`` under its ``"ip"`` field. A message
    that fails to decode or lacks ``"ip"`` is logged with a descriptive
    message and skipped, so the loop keeps running.
    """
    aioredis_client = self._dashboard_head.aioredis_client
    receiver = Receiver()
    reporter_key = "{}*".format(reporter_consts.REPORTER_PREFIX)
    await aioredis_client.psubscribe(receiver.pattern(reporter_key))
    logger.info("Subscribed to %s", reporter_key)
    async for _sender, msg in receiver.iter():
        try:
            # Pattern messages arrive as (channel name, payload).
            _, data = msg
            data = json.loads(ray.utils.decode(data))
            DataSource.node_physical_stats[data["ip"]] = data
        except Exception:
            # logger.exception appends the traceback itself; give it a real
            # message instead of passing the exception object as the text.
            logger.exception(
                "Error receiving node physical stats from reporter agent.")
async def run(self, server):
    """Consume reporter-agent stats from Redis pub/sub forever.

    Each message on a ``REPORTER_PREFIX*`` channel is JSON-decoded and stored
    in ``DataSource.node_physical_stats`` under its ``"ip"`` field. Failures
    on individual messages are logged and skipped. ``server`` is accepted for
    interface compatibility and is not used here.
    """
    redis_client = self._dashboard_head.aioredis_client
    stats_receiver = Receiver()
    pattern_key = "{}*".format(reporter_consts.REPORTER_PREFIX)
    await redis_client.psubscribe(stats_receiver.pattern(pattern_key))
    logger.info(f"Subscribed to {pattern_key}")

    async for _sender, message in stats_receiver.iter():
        try:
            _channel, payload = message
            stats = json.loads(ray.utils.decode(payload))
            DataSource.node_physical_stats[stats["ip"]] = stats
        except Exception:
            logger.exception(
                "Error receiving node physical stats from reporter agent.")
async def _update_jobs(self):
    """Keep ``DataSource.jobs`` in sync with GCS.

    1. Subscribe to the ``{JOB_CHANNEL}:*`` Redis pattern.
    2. Load the full job table from GCS, retrying every
       ``job_consts.RETRY_GET_ALL_JOB_INFO_INTERVAL_SECONDS`` until a reply
       with status code 0 arrives, and reset ``DataSource.jobs`` with it.
    3. Apply each job update received on the channel, forever.
    """
    # Subscribe first; messages published while the snapshot is fetched are
    # queued in the receiver and applied in step 3.
    redis_client = self._dashboard_head.aioredis_client
    job_receiver = Receiver()
    channel_key = f"{job_consts.JOB_CHANNEL}:*"
    await redis_client.psubscribe(job_receiver.pattern(channel_key))
    logger.info("Subscribed to %s", channel_key)

    while True:
        try:
            logger.info("Getting all job info from GCS.")
            reply = await self._gcs_job_info_stub.GetAllJobInfo(
                gcs_service_pb2.GetAllJobInfoRequest(), timeout=5)
            if reply.status.code != 0:
                # Routed to the except clause below: logged, then retried.
                raise Exception(
                    f"Failed to GetAllJobInfo: {reply.status.message}")
            snapshot = {
                job["jobId"]: job
                for job in map(job_table_data_to_dict, reply.job_info_list)
            }
            DataSource.jobs.reset(snapshot)
            logger.info("Received %d job info from GCS.", len(snapshot))
            break
        except Exception:
            logger.exception("Error Getting all job info from GCS.")
            await asyncio.sleep(
                job_consts.RETRY_GET_ALL_JOB_INFO_INTERVAL_SECONDS)

    async for _sender, message in job_receiver.iter():
        try:
            _channel, payload = message
            envelope = ray.gcs_utils.PubSubMessage.FromString(payload)
            job_info = job_table_data_to_dict(
                ray.gcs_utils.JobTableData.FromString(envelope.data))
            DataSource.jobs[job_info["jobId"]] = job_info
        except Exception:
            logger.exception("Error receiving job info.")
async def _update_error_info(self):
    """Record worker error reports in ``DataSource.ip_and_pid_to_errors``.

    Errors are read from ``self._dashboard_head.gcs_error_subscriber`` when
    it is set, and otherwise from the Redis pattern
    ``gcs_utils.RAY_ERROR_PUBSUB_PATTERN``. Only messages containing a
    ``(pid=<pid>, ip=<ip>)`` tag are stored, grouped by ip then pid, as dicts
    with ``message``, ``timestamp`` and ``type`` keys. Runs forever; failures
    on individual messages are logged and skipped.
    """

    def process_error(error_data):
        # Drop ANSI color codes such as "\x1b[31m" before parsing/storing.
        message = re.sub(r"\x1b\[\d+m", "", error_data.error_message)
        found = re.search(r"\(pid=(\d+), ip=(.*?)\)", message)
        if found is None:
            return
        pid, ip = found.groups()
        # Build fresh containers rather than mutating the stored ones, then
        # write the new per-ip mapping back.
        errs_for_ip = dict(DataSource.ip_and_pid_to_errors.get(ip, {}))
        errs_for_ip[pid] = [
            *errs_for_ip.get(pid, []),
            {
                "message": message,
                "timestamp": error_data.timestamp,
                "type": error_data.type,
            },
        ]
        DataSource.ip_and_pid_to_errors[ip] = errs_for_ip
        logger.info(f"Received error entry for {ip} {pid}")

    async def consume_from_gcs():
        while True:
            try:
                _, error_data = await \
                    self._dashboard_head.gcs_error_subscriber.poll()
                if error_data is not None:
                    process_error(error_data)
            except Exception:
                logger.exception("Error receiving error info from GCS.")

    async def consume_from_redis():
        redis_client = self._dashboard_head.aioredis_client
        error_receiver = Receiver()
        key = gcs_utils.RAY_ERROR_PUBSUB_PATTERN
        await redis_client.psubscribe(error_receiver.pattern(key))
        logger.info("Subscribed to %s", key)
        async for _, message in error_receiver.iter():
            try:
                _, payload = message
                envelope = gcs_utils.PubSubMessage.FromString(payload)
                process_error(
                    gcs_utils.ErrorTableData.FromString(envelope.data))
            except Exception:
                logger.exception("Error receiving error info from Redis.")

    if self._dashboard_head.gcs_error_subscriber:
        await consume_from_gcs()
    else:
        await consume_from_redis()
async def _update_actors(self):
    """Keep ``DataSource.actors`` in sync with GCS.

    Subscribes to the ``{ACTOR_CHANNEL}:*`` Redis pattern, loads all actors
    from GCS (retrying every
    ``stats_collector_consts.RETRY_GET_ALL_ACTOR_INFO_INTERVAL_SECONDS`` until
    status code 0), resets ``DataSource.actors`` keyed by hex actor id, then
    applies channel updates forever.
    """
    redis_client = self._dashboard_head.aioredis_client
    actor_receiver = Receiver()
    channel_key = "{}:*".format(stats_collector_consts.ACTOR_CHANNEL)
    await redis_client.psubscribe(actor_receiver.pattern(channel_key))
    logger.info("Subscribed to %s", channel_key)

    while True:
        try:
            logger.info("Getting all actor info from GCS.")
            reply = await self._gcs_actor_info_stub.GetAllActorInfo(
                gcs_service_pb2.GetAllActorInfoRequest(), timeout=2)
            if reply.status.code != 0:
                # Routed to the except clause below: logged, then retried.
                raise Exception(
                    f"Failed to GetAllActorInfo: {reply.status.message}")
            snapshot = {}
            for actor_info in reply.actor_table_data:
                actor_dict = actor_table_data_to_dict(actor_info)
                snapshot[binary_to_hex(actor_info.actor_id)] = actor_dict
            DataSource.actors.reset(snapshot)
            logger.info("Received %d actor info from GCS.", len(snapshot))
            break
        except Exception:
            logger.exception("Error Getting all actor info from GCS.")
            await asyncio.sleep(
                stats_collector_consts.RETRY_GET_ALL_ACTOR_INFO_INTERVAL_SECONDS)

    async for _sender, message in actor_receiver.iter():
        try:
            _channel, payload = message
            envelope = ray.gcs_utils.PubSubMessage.FromString(payload)
            actor_info = ray.gcs_utils.ActorTableData.FromString(envelope.data)
            actor_dict = actor_table_data_to_dict(actor_info)
            DataSource.actors[binary_to_hex(actor_info.actor_id)] = actor_dict
        except Exception:
            logger.exception("Error receiving actor info.")
async def test_decode_message_for_pattern():
    """Pattern messages keep the raw channel name; only the payload is decoded."""
    receiver = Receiver()
    pattern_ch = receiver.pattern("*")

    # (queued message, keyword arguments for get(), expected decoded message)
    cases = [
        ((b"channel", b"Some data"),
         {"encoding": "utf-8"},
         (b"channel", "Some data")),
        ((b"channel", '{"hello": "world"}'),
         {"decoder": json.loads},
         (b"channel", {"hello": "world"})),
        ((b"channel", b'{"hello": "world"}'),
         {"encoding": "utf-8", "decoder": json.loads},
         (b"channel", {"hello": "world"})),
    ]
    for queued, get_kwargs, expected in cases:
        pattern_ch.put_nowait(queued)
        sender, decoded = await receiver.get(**get_kwargs)
        assert isinstance(sender, _Sender)
        assert decoded == expected
async def test_decode_message_for_pattern(loop):
    """``get()`` decodes only the payload of pattern messages.

    The destination channel name stays raw bytes, while ``encoding`` and/or
    ``decoder`` are applied to the payload.
    """
    mpsc = Receiver(loop)
    pattern = mpsc.pattern('*')

    async def push_and_get(raw, **get_kwargs):
        pattern.put_nowait(raw)
        sender, message = await mpsc.get(**get_kwargs)
        assert isinstance(sender, _Sender)
        return message

    got = await push_and_get((b'channel', b'Some data'), encoding='utf-8')
    assert got == (b'channel', 'Some data')

    got = await push_and_get((b'channel', '{"hello": "world"}'),
                             decoder=json.loads)
    assert got == (b'channel', {'hello': 'world'})

    got = await push_and_get((b'channel', b'{"hello": "world"}'),
                             encoding='utf-8', decoder=json.loads)
    assert got == (b'channel', {'hello': 'world'})
async def run(self, server):
    """Stream reporter-agent stats from Redis into ``DataSource``.

    Subscribes to ``REPORTER_PREFIX*``. Each JSON payload is stored in
    ``DataSource.node_physical_stats`` under the node id, taken as the part
    of the UTF-8 channel name after its last ``:``. Failures on individual
    messages are logged and skipped. ``server`` is not used.
    """
    redis_client = self._dashboard_head.aioredis_client
    stats_receiver = Receiver()
    pattern_key = "{}*".format(reporter_consts.REPORTER_PREFIX)
    await redis_client.psubscribe(stats_receiver.pattern(pattern_key))
    logger.info(f"Subscribed to {pattern_key}")

    async for _sender, message in stats_receiver.iter():
        try:
            # The channel is b'RAY_REPORTER:{node id hex}', e.g.
            # b'RAY_REPORTER:2b4fbd406898cc86fb88fb0acfd5456b0afd87cf'
            channel_name, payload = message
            stats = json.loads(ray._private.utils.decode(payload))
            node_id = channel_name.decode("utf-8").split(":")[-1]
            DataSource.node_physical_stats[node_id] = stats
        except Exception:
            logger.exception(
                "Error receiving node physical stats from reporter agent.")
async def run(self, server):
    """Start service discovery, then feed node physical stats into DataSource.

    Stats come from the GCS resource-usage subscriber when
    ``gcs_pubsub_enabled()`` is true, otherwise from the Redis
    ``REPORTER_PREFIX*`` pattern. Either way each payload is JSON-decoded and
    stored in ``DataSource.node_physical_stats`` under the node id found after
    the last ``:`` of its key. Runs forever; bad messages are logged and
    skipped. ``server`` is not used.
    """
    # Need daemon True to avoid dashboard hangs at exit.
    self.service_discovery.daemon = True
    self.service_discovery.start()

    def store(key, stats):
        # Keys look like 'RAY_REPORTER:{node id hex}'; keep the node id part.
        DataSource.node_physical_stats[key.split(":")[-1]] = stats

    if not gcs_pubsub_enabled():
        from aioredis.pubsub import Receiver
        receiver = Receiver()
        aioredis_client = self._dashboard_head.aioredis_client
        reporter_key = "{}*".format(reporter_consts.REPORTER_PREFIX)
        await aioredis_client.psubscribe(receiver.pattern(reporter_key))
        logger.info(f"Subscribed to {reporter_key}")
        async for _sender, message in receiver.iter():
            try:
                raw_key, payload = message
                stats = json.loads(ray._private.utils.decode(payload))
                store(raw_key.decode("utf-8"), stats)
            except Exception:
                logger.exception("Error receiving node physical stats "
                                 "from reporter agent.")
    else:
        subscriber = GcsAioResourceUsageSubscriber(
            self._dashboard_head.gcs_address)
        await subscriber.subscribe()
        while True:
            try:
                key, payload = await subscriber.poll()
                # A None key means nothing usable was polled; try again.
                if key is not None:
                    store(key, json.loads(payload))
            except Exception:
                logger.exception("Error receiving node physical stats "
                                 "from reporter agent.")
async def test_pubsub_receiver_stop_on_disconnect(create_redis, server, loop):
    """Killing the subscriber connection terminates ``Receiver.iter()``.

    Messages are first delivered via two plain channels and one overlapping
    pattern. Then the subscriber's connection is killed with ``CLIENT KILL``,
    and the reader task must finish within 1 second, enqueueing ``EOF``.

    ``asyncio.Queue`` and ``asyncio.wait_for`` are called without ``loop=``.
    That parameter was removed in Python 3.10 (it raises ``TypeError`` there).
    Created and awaited from inside this coroutine, both use the running loop.
    """
    pub = await create_redis(server.tcp_address, loop=loop)
    sub = await create_redis(server.tcp_address, loop=loop)
    sub_name = 'sub-{:X}'.format(id(sub))
    await sub.client_setname(sub_name)
    # Find the subscriber's entry (its address is needed for CLIENT KILL)
    # without leaving ``sub_info`` unbound when no client matches.
    clients = await pub.client_list()
    sub_info = next((info for info in clients if info.name == sub_name), None)
    assert sub_info is not None, "subscriber connection not found"

    mpsc = Receiver(loop=loop)
    await sub.subscribe(mpsc.channel('channel:1'))
    await sub.subscribe(mpsc.channel('channel:2'))
    await sub.psubscribe(mpsc.pattern('channel:*'))

    q = asyncio.Queue()
    EOF = object()

    async def reader():
        async for ch, msg in mpsc.iter(encoding='utf-8'):
            await q.put((ch.name, msg))
        await q.put(EOF)

    tsk = asyncio.ensure_future(reader(), loop=loop)

    await pub.publish_json('channel:1', ['hello'])
    await pub.publish_json('channel:2', ['hello'])
    # receive all messages
    assert await q.get() == (b'channel:1', '["hello"]')
    assert await q.get() == (b'channel:*', (b'channel:1', '["hello"]'))
    assert await q.get() == (b'channel:2', '["hello"]')
    assert await q.get() == (b'channel:*', (b'channel:2', '["hello"]'))

    # XXX: need to implement `client kill`
    assert await pub.execute('client', 'kill', sub_info.addr) in (b'OK', 1)
    await asyncio.wait_for(tsk, timeout=1)
    assert await q.get() is EOF
async def test_pubsub_receiver_stop_on_disconnect(create_redis, server):
    """Killing the subscriber connection terminates ``Receiver.iter()``.

    After messages flow through two plain channels and an overlapping
    pattern, the subscriber connection is killed. The reader task must then
    finish within 1 second and enqueue ``EOF``.
    """
    publisher = await create_redis(server.tcp_address)
    subscriber = await create_redis(server.tcp_address)
    client_name = "sub-{:X}".format(id(subscriber))
    await subscriber.client_setname(client_name)
    for client_info in await publisher.client_list():
        if client_info.name == client_name:
            break
    assert client_info.name == client_name

    mpsc = Receiver()
    for channel_name in ("channel:1", "channel:2"):
        await subscriber.subscribe(mpsc.channel(channel_name))
    await subscriber.psubscribe(mpsc.pattern("channel:*"))

    queue = asyncio.Queue()
    EOF = object()

    async def reader():
        async for channel, msg in mpsc.iter(encoding="utf-8"):
            await queue.put((channel.name, msg))
        await queue.put(EOF)

    task = asyncio.ensure_future(reader())

    for channel_name in ("channel:1", "channel:2"):
        await publisher.publish_json(channel_name, ["hello"])

    # Each publish arrives once via its plain channel and once via the
    # pattern, in this order.
    expected = [
        (b"channel:1", '["hello"]'),
        (b"channel:*", (b"channel:1", '["hello"]')),
        (b"channel:2", '["hello"]'),
        (b"channel:*", (b"channel:2", '["hello"]')),
    ]
    for item in expected:
        assert await queue.get() == item

    # XXX: need to implement `client kill`
    kill_reply = await publisher.execute("client", "kill", client_info.addr)
    assert kill_reply in (b"OK", 1)
    await asyncio.wait_for(task, timeout=1)
    assert await queue.get() is EOF
class WebsocketServer:
    """Websocket proxy for public redis channels and hashes.

    Besides relaying channel data, the server supports some geo operations;
    see ``WebsocketHandler`` and ``protocol.CommandsMixin`` for details.
    """

    handler_class = WebsocketHandler

    def __init__(self, redis, subscriber, read_timeout=None,
                 keep_alive_timeout=None):
        """Store the defaults shared by every new WebsocketHandler.

        :param redis: aioredis.StrictRedis instance
        :param subscriber: aioredis.StrictRedis instance
        :param read_timeout: Timeout after which the websocket connection is
            checked and kept if still open (does not cancel an open
            connection)
        :param keep_alive_timeout: Time after which the server cancels the
            handler task (independently of its internal state)
        """
        self.redis = redis
        self.subscriber = subscriber
        self.read_timeout = read_timeout
        self.keep_alive_timeout = keep_alive_timeout
        # One receiver for all subscriptions; handlers keyed by peer address.
        self.receiver = Receiver()
        self.handlers = {}

    async def websocket_handler(self, websocket, path):
        """Serve one websocket connection until it finishes or times out."""
        logger.info("Client %s connected", websocket.remote_address)
        known_channels = {
            bytes.decode(name) for name in self.receiver.channels.keys()}
        known_patterns = {
            bytes.decode(name) for name in self.receiver.patterns.keys()}
        handler = await self.handler_class.create(
            self.redis,
            websocket,
            known_channels,
            known_patterns,
            read_timeout=self.read_timeout,
        )
        self.handlers[websocket.remote_address] = handler
        try:
            await asyncio.wait_for(handler.listen(), self.keep_alive_timeout)
        finally:
            # Always unregister and close, even on timeout or error.
            del self.handlers[websocket.remote_address]
            await handler.close()
            logger.info("Client %s removed", websocket.remote_address)

    async def redis_subscribe(self, channel_names=None, channel_patterns=None):
        """Subscribe the shared receiver to channel names and/or patterns.

        :raises ValueError: if neither names nor patterns are given
        """
        if not channel_names and not channel_patterns:
            raise ValueError("Got nothing to subscribe to")
        if channel_names:
            channels = [self.receiver.channel(name) for name in channel_names]
            await self.subscriber.subscribe(*channels)
        if channel_patterns:
            patterns = [self.receiver.pattern(pat) for pat in channel_patterns]
            await self.subscriber.psubscribe(*patterns)

    async def redis_reader(self):
        """Fan each received redis message out to every interested handler."""
        async for channel, msg in self.receiver.iter(encoding="utf-8"):
            if channel.is_pattern:
                # Pattern deliveries carry (actual channel name, payload).
                source = msg[0].decode()
                msg = msg[1]
            else:
                source = channel.name.decode()
            for handler in self.handlers.values():
                if source in handler.subscriptions:
                    handler.queue.put_nowait(
                        Message(source=source, content=msg))

    def listen(self, host, port, channel_names=None, channel_patterns=None,
               loop=None):
        """Subscribe to redis, start the websocket server, then relay forever."""
        event_loop = loop if loop else asyncio.get_event_loop()
        start_server = serve(self.websocket_handler, host, port)
        event_loop.run_until_complete(
            self.redis_subscribe(channel_names, channel_patterns))
        event_loop.run_until_complete(start_server)
        logger.info("Listening on %s:%s...", host, port)
        event_loop.run_until_complete(self.redis_reader())
async def _update_actors(self):
    """Keep DataSource.actors, job_actors and node_actors in sync with GCS.

    Subscribes to the ``{ACTOR_CHANNEL}:*`` Redis pattern, then loads the
    full actor table from GCS. On failure it retries every
    ``actor_consts.RETRY_GET_ALL_ACTOR_INFO_INTERVAL_SECONDS``. After that it
    applies actor updates received on the channel forever. Actors whose
    ``rayletId`` equals ``actor_consts.NIL_NODE_ID`` are not added to
    ``node_actors``.
    """
    # TODO(fyrestone): Refactor code for updating actor / node / job.
    # Subscribe actor channel.
    aioredis_client = self._dashboard_head.aioredis_client
    receiver = Receiver()
    key = "{}:*".format(actor_consts.ACTOR_CHANNEL)
    pattern = receiver.pattern(key)
    await aioredis_client.psubscribe(pattern)
    logger.info("Subscribed to %s", key)

    def _process_actor_table_data(data):
        # Add an "actorClass" entry, derived from the task spec (an empty
        # spec if absent), to the actor dict in place.
        actor_class = actor_classname_from_task_spec(
            data.get("taskSpec", {}))
        data["actorClass"] = actor_class

    # Get all actor info.
    while True:
        try:
            logger.info("Getting all actor info from GCS.")
            request = gcs_service_pb2.GetAllActorInfoRequest()
            reply = await self._gcs_actor_info_stub.GetAllActorInfo(
                request, timeout=5)
            if reply.status.code == 0:
                actors = {}
                for message in reply.actor_table_data:
                    actor_table_data = actor_table_data_to_dict(message)
                    _process_actor_table_data(actor_table_data)
                    actors[actor_table_data["actorId"]] = actor_table_data
                # Update actors.
                DataSource.actors.reset(actors)
                # Update node actors and job actors.
                job_actors = {}
                node_actors = {}
                for actor_id, actor_table_data in actors.items():
                    job_id = actor_table_data["jobId"]
                    node_id = actor_table_data["address"]["rayletId"]
                    job_actors.setdefault(job_id,
                                          {})[actor_id] = actor_table_data
                    # Update only when node_id is not Nil.
                    if node_id != actor_consts.NIL_NODE_ID:
                        node_actors.setdefault(
                            node_id, {})[actor_id] = actor_table_data
                DataSource.job_actors.reset(job_actors)
                DataSource.node_actors.reset(node_actors)
                logger.info("Received %d actor info from GCS.",
                            len(actors))
                break
            else:
                # Raised so the except clause below logs it and sleeps
                # before the next attempt.
                raise Exception(
                    f"Failed to GetAllActorInfo: {reply.status.message}")
        except Exception:
            logger.exception("Error Getting all actor info from GCS.")
            await asyncio.sleep(
                actor_consts.RETRY_GET_ALL_ACTOR_INFO_INTERVAL_SECONDS)

    # Receive actors from channel.
    # Fields merged into the cached record when an already-registered actor
    # is updated (any state other than DEPENDENCIES_UNREADY).
    state_keys = ("state", "address", "numRestarts", "timestamp", "pid")
    async for sender, msg in receiver.iter():
        try:
            # Pattern messages are (channel name bytes, payload bytes).
            actor_id, actor_table_data = msg
            pubsub_message = ray.gcs_utils.PubSubMessage.FromString(
                actor_table_data)
            message = ray.gcs_utils.ActorTableData.FromString(
                pubsub_message.data)
            actor_table_data = actor_table_data_to_dict(message)
            _process_actor_table_data(actor_table_data)
            # If actor is not new registered but updated, we only update
            # states related fields.
            if actor_table_data["state"] != "DEPENDENCIES_UNREADY":
                # Slice off the leading "<TablePrefix_ACTOR_string>:" from
                # the channel name to get the actor id.
                # NOTE(review): assumes the channel name starts with exactly
                # that prefix (i.e. ACTOR_CHANNEL matches it) — confirm.
                actor_id = actor_id.decode(
                    "UTF-8")[len(ray.gcs_utils.TablePrefix_ACTOR_string + ":"):]
                # NOTE(review): raises KeyError (logged below, update dropped)
                # if this actor is not already in DataSource.actors.
                actor_table_data_copy = dict(DataSource.actors[actor_id])
                for k in state_keys:
                    actor_table_data_copy[k] = actor_table_data[k]
                actor_table_data = actor_table_data_copy
            actor_id = actor_table_data["actorId"]
            job_id = actor_table_data["jobId"]
            node_id = actor_table_data["address"]["rayletId"]
            # Update actors.
            DataSource.actors[actor_id] = actor_table_data
            # Update node actors (only when node_id is not Nil).
            # Copy-and-replace: a new dict is stored rather than mutating
            # the existing one.
            if node_id != actor_consts.NIL_NODE_ID:
                node_actors = dict(DataSource.node_actors.get(node_id, {}))
                node_actors[actor_id] = actor_table_data
                DataSource.node_actors[node_id] = node_actors
            # Update job actors.
            job_actors = dict(DataSource.job_actors.get(job_id, {}))
            job_actors[actor_id] = actor_table_data
            DataSource.job_actors[job_id] = job_actors
        except Exception:
            logger.exception("Error receiving actor info.")
async def _update_actors(self):
    """Mirror GCS actor state into DataSource.actors/job_actors/node_actors.

    Subscribes to ``{ACTOR_CHANNEL}:*``, loads a full actor snapshot from
    GCS (retrying every
    ``stats_collector_consts.RETRY_GET_ALL_ACTOR_INFO_INTERVAL_SECONDS`` until
    status code 0), and then applies each channel update. Every actor dict
    gets an ``actorClass`` entry and is indexed by actor id, by job id and by
    raylet (node) id.
    """
    redis_client = self._dashboard_head.aioredis_client
    actor_receiver = Receiver()
    channel_key = "{}:*".format(stats_collector_consts.ACTOR_CHANNEL)
    await redis_client.psubscribe(actor_receiver.pattern(channel_key))
    logger.info("Subscribed to %s", channel_key)

    def add_actor_class(actor):
        # Derive the class name from the task spec (empty spec if missing).
        actor["actorClass"] = actor_classname_from_task_spec(
            actor.get("taskSpec", {}))

    while True:
        try:
            logger.info("Getting all actor info from GCS.")
            reply = await self._gcs_actor_info_stub.GetAllActorInfo(
                gcs_service_pb2.GetAllActorInfoRequest(), timeout=5)
            if reply.status.code != 0:
                # Routed to the except clause below: logged, then retried.
                raise Exception(
                    f"Failed to GetAllActorInfo: {reply.status.message}")
            actors = {}
            for proto in reply.actor_table_data:
                actor = actor_table_data_to_dict(proto)
                add_actor_class(actor)
                actors[actor["actorId"]] = actor
            DataSource.actors.reset(actors)

            by_job, by_node = {}, {}
            for actor_id, actor in actors.items():
                job_id = actor["jobId"]
                node_id = actor["address"]["rayletId"]
                by_job.setdefault(job_id, {})[actor_id] = actor
                by_node.setdefault(node_id, {})[actor_id] = actor
            DataSource.job_actors.reset(by_job)
            DataSource.node_actors.reset(by_node)
            logger.info("Received %d actor info from GCS.", len(actors))
            break
        except Exception:
            logger.exception("Error Getting all actor info from GCS.")
            await asyncio.sleep(
                stats_collector_consts.RETRY_GET_ALL_ACTOR_INFO_INTERVAL_SECONDS)

    async for _sender, message in actor_receiver.iter():
        try:
            _channel, payload = message
            envelope = ray.gcs_utils.PubSubMessage.FromString(payload)
            actor = actor_table_data_to_dict(
                ray.gcs_utils.ActorTableData.FromString(envelope.data))
            add_actor_class(actor)
            actor_id = actor["actorId"]
            job_id = actor["jobId"]
            node_id = actor["address"]["rayletId"]
            DataSource.actors[actor_id] = actor
            # Store new dicts instead of mutating the existing ones.
            DataSource.node_actors[node_id] = {
                **DataSource.node_actors.get(node_id, {}), actor_id: actor}
            DataSource.job_actors[job_id] = {
                **DataSource.job_actors.get(job_id, {}), actor_id: actor}
        except Exception:
            logger.exception("Error receiving actor info.")
# Usage example: route several Redis channels and a pattern through one
# aioredis Receiver and consume them with a single reader task.
# NOTE(review): this is REPL/doctest-style code. It uses top-level ``await``
# and expects ``asyncio``, ``loop`` and ``redis`` to already exist, so it does
# not run as a plain module (an asyncio REPL allows top-level await) — confirm
# the intended context.
from aioredis.pubsub import Receiver
from aioredis.abc import AbcChannel

mpsc = Receiver(loop=loop)


async def reader(mpsc):
    # Print every (channel, message) pair delivered to the receiver.
    async for channel, msg in mpsc.iter():
        assert isinstance(channel, AbcChannel)
        print("Got {!r} in channel {!r}".format(msg, channel))

asyncio.ensure_future(reader(mpsc))

await redis.subscribe(mpsc.channel('channel:1'),
                      mpsc.channel('channel:3'),
                      mpsc.channel('channel:5'))

await redis.psubscribe(mpsc.pattern('hello'))
# publishing 'Hello world' into 'hello-channel'
# will print this message:
# NOTE(review): the expected output line is missing here. Also, the pattern
# 'hello' has no glob wildcard, so it would seem to match only a channel named
# exactly 'hello', not 'hello-channel' — confirm (perhaps 'hello*' was meant).

# when all is done:
await redis.unsubscribe('channel:1', 'channel:3', 'channel:5')
await redis.punsubscribe('hello')
mpsc.stop()
# any message received after stop() will be ignored.