async def test_pubsub_receiver_iter(create_redis, server, loop):
    """Two channels plus an overlapping pattern all feed one ``Receiver``.

    Every message published is yielded by ``Receiver.iter()`` once per
    matching subscription, in arrival order, and the receiver reports itself
    inactive after ``stop()``.
    """
    sub = await create_redis(server.tcp_address, loop=loop)
    pub = await create_redis(server.tcp_address, loop=loop)

    receiver = Receiver(loop=loop)

    async def drain(source):
        # Collect everything the receiver yields until iteration ends.
        return [item async for item in source.iter()]

    collector = asyncio.ensure_future(drain(receiver), loop=loop)
    [chan1] = await sub.subscribe(receiver.channel('chan:1'))
    [chan2] = await sub.subscribe(receiver.channel('chan:2'))
    [star] = await sub.psubscribe(receiver.pattern('chan:*'))

    # The test expects more than one receiver per publish (presumably the
    # plain-channel and the pattern subscription are counted separately).
    assert await pub.publish_json('chan:1', {'Hello': 'World'}) > 1
    assert await pub.publish_json('chan:2', ['message']) > 1

    # Schedule stop() through the loop instead of calling it inline.
    loop.call_later(0, receiver.stop)
    expected = [
        (chan1, b'{"Hello": "World"}'),
        (star, (b'chan:1', b'{"Hello": "World"}')),
        (chan2, b'["message"]'),
        (star, (b'chan:2', b'["message"]')),
    ]
    assert await collector == expected
    assert not receiver.is_active
async def test_decode_message_error(loop):
    """``Receiver.get(decoder=json.loads)`` propagates ``TypeError`` for
    raw ``bytes`` payloads, for plain channels and for patterns alike.

    ``json.loads`` only rejects ``bytes`` before Python 3.6; from 3.6 on it
    decodes them, so there is no error to observe and the test is skipped.
    """
    import sys

    if sys.version_info >= (3, 6):
        pytest.skip("json.loads accepts bytes since Python 3.6")

    mpsc = Receiver(loop)

    ch = mpsc.channel('channel:1')
    ch.put_nowait(b'{"hello": "world"}')
    with pytest.raises(TypeError):
        # get() must raise before returning anything, so no result is
        # compared here.
        await mpsc.get(decoder=json.loads)

    ch = mpsc.pattern('*')
    ch.put_nowait((b'channel', b'{"hello": "world"}'))
    with pytest.raises(TypeError):
        await mpsc.get(decoder=json.loads)
def test_listener_pattern(loop):
    """``Receiver.pattern`` caches one channel per pattern, and the
    receiver's ``is_active`` flag follows whether any pattern is open."""
    receiver = Receiver(loop=loop)
    assert not receiver.is_active

    first = receiver.pattern("*")
    assert isinstance(first, AbcChannel)
    assert receiver.is_active

    # Asking again for the same pattern returns the very same object.
    second = receiver.pattern('*')
    assert first is second
    assert first.name == second.name
    assert first.is_pattern == second.is_pattern
    assert receiver.is_active

    # Closing the only pattern deactivates both it and the receiver ...
    first.close()
    assert not first.is_active
    assert not receiver.is_active

    # ... and the next request for that pattern yields a new channel object.
    fresh = receiver.pattern("*")
    assert fresh is not first

    assert dict(receiver.channels) == {}
    assert dict(receiver.patterns) == {b'*': fresh}
Example #4
0
    async def run(self):
        """Forward node physical stats published by reporter agents.

        Pattern-subscribes to every channel starting with
        ``reporter_consts.REPORTER_PREFIX`` and stores each JSON payload in
        ``DataSource.node_physical_stats`` keyed by its ``"ip"`` field.
        Messages that fail to decode are logged and skipped so one bad
        payload does not end the loop.
        """
        aioredis_client = self._dashboard_head.aioredis_client
        mpsc = Receiver()

        reporter_key = "{}*".format(reporter_consts.REPORTER_PREFIX)
        await aioredis_client.psubscribe(mpsc.pattern(reporter_key))
        logger.info("Subscribed to {}".format(reporter_key))

        async for _, msg in mpsc.iter():
            try:
                # Pattern messages arrive as (channel, payload).
                _, data = msg
                data = json.loads(ray.utils.decode(data))
                DataSource.node_physical_stats[data["ip"]] = data
            except Exception:
                # Use a descriptive message; logger.exception attaches the
                # traceback itself, so passing the exception object as the
                # message only duplicated (or, for empty str(ex), lost) info.
                logger.exception(
                    "Error receiving node physical stats from reporter agent.")
def test_listener_pattern(loop):
    """Exercise pattern-channel caching and activity tracking on Receiver.

    Repeated ``pattern('*')`` calls share one channel object. Closing it
    makes both the channel and the receiver inactive, and a later
    ``pattern('*')`` creates a replacement that is the only entry in
    ``patterns``.
    """
    mpsc = Receiver(loop=loop)
    assert not mpsc.is_active

    original = mpsc.pattern("*")
    assert isinstance(original, AbcChannel)
    assert mpsc.is_active

    same = mpsc.pattern('*')
    assert original is same
    for attr in ("name", "is_pattern"):
        assert getattr(original, attr) == getattr(same, attr)
    assert mpsc.is_active

    original.close()
    assert not original.is_active
    assert not mpsc.is_active

    replacement = mpsc.pattern("*")
    assert replacement is not original

    assert dict(mpsc.channels) == {}
    assert dict(mpsc.patterns) == {b'*': replacement}
Example #6
0
    async def run(self, server):
        """Consume reporter-agent stats from Redis pub/sub forever.

        ``server`` is not used by this method. Every message on a channel
        matching ``REPORTER_PREFIX*`` is JSON-decoded and stored under its
        ``"ip"`` field in ``DataSource.node_physical_stats``; failures are
        logged and the loop continues.
        """
        client = self._dashboard_head.aioredis_client
        stats_receiver = Receiver()

        pattern_key = "{}*".format(reporter_consts.REPORTER_PREFIX)
        await client.psubscribe(stats_receiver.pattern(pattern_key))
        logger.info(f"Subscribed to {pattern_key}")

        async for _sender, message in stats_receiver.iter():
            try:
                # Pattern messages are (matched channel, payload) pairs.
                _channel, raw = message
                stats = json.loads(ray.utils.decode(raw))
                DataSource.node_physical_stats[stats["ip"]] = stats
            except Exception:
                logger.exception(
                    "Error receiving node physical stats from reporter agent.")
Example #7
0
    async def _update_jobs(self):
        """Keep ``DataSource.jobs`` in sync with the GCS job table.

        Subscribes to ``JOB_CHANNEL:*`` first, then calls ``GetAllJobInfo``
        (timeout=5) until it succeeds and resets ``DataSource.jobs`` from the
        snapshot. After that it applies every job update published on the
        channel, logging and skipping messages that cannot be decoded.
        """
        client = self._dashboard_head.aioredis_client
        job_receiver = Receiver()

        # Subscribe before taking the snapshot; presumably so that updates
        # published in between are queued rather than missed.
        channel_key = f"{job_consts.JOB_CHANNEL}:*"
        await client.psubscribe(job_receiver.pattern(channel_key))
        logger.info("Subscribed to %s", channel_key)

        # Load the initial job snapshot, retrying after any failure.
        while True:
            try:
                logger.info("Getting all job info from GCS.")
                reply = await self._gcs_job_info_stub.GetAllJobInfo(
                    gcs_service_pb2.GetAllJobInfoRequest(), timeout=5)
                if reply.status.code != 0:
                    raise Exception(
                        f"Failed to GetAllJobInfo: {reply.status.message}")
                snapshot = {}
                for entry in reply.job_info_list:
                    as_dict = job_table_data_to_dict(entry)
                    snapshot[as_dict["jobId"]] = as_dict
                DataSource.jobs.reset(snapshot)
                logger.info("Received %d job info from GCS.", len(snapshot))
                break
            except Exception:
                logger.exception("Error Getting all job info from GCS.")
                await asyncio.sleep(
                    job_consts.RETRY_GET_ALL_JOB_INFO_INTERVAL_SECONDS)

        # Apply incremental updates published on the job channel.
        async for _sender, message in job_receiver.iter():
            try:
                _channel, payload = message
                envelope = ray.gcs_utils.PubSubMessage.FromString(payload)
                job_data = job_table_data_to_dict(
                    ray.gcs_utils.JobTableData.FromString(envelope.data))
                DataSource.jobs[job_data["jobId"]] = job_data
            except Exception:
                logger.exception("Error receiving job info.")
Example #8
0
    async def _update_error_info(self):
        """Record reported errors in ``DataSource.ip_and_pid_to_errors``.

        Errors come from ``self._dashboard_head.gcs_error_subscriber`` when it
        is set. Otherwise they come from a Redis pattern subscription on
        ``gcs_utils.RAY_ERROR_PUBSUB_PATTERN``. Only messages containing a
        ``(pid=<pid>, ip=<ip>)`` marker are stored; each one becomes a
        ``{"message", "timestamp", "type"}`` entry appended to the list for
        that ip and pid. Both branches loop forever, logging failures.
        """
        def record(error_data):
            # Drop escape sequences of the form ESC[<digits>m first.
            cleaned = re.sub(r"\x1b\[\d+m", "", error_data.error_message)
            found = re.search(r"\(pid=(\d+), ip=(.*?)\)", cleaned)
            if found is None:
                return
            pid, ip = found.group(1), found.group(2)
            # Build fresh copies instead of mutating the stored containers.
            per_pid = dict(DataSource.ip_and_pid_to_errors.get(ip, {}))
            history = list(per_pid.get(pid, []))
            history.append({
                "message": cleaned,
                "timestamp": error_data.timestamp,
                "type": error_data.type
            })
            per_pid[pid] = history
            DataSource.ip_and_pid_to_errors[ip] = per_pid
            logger.info(f"Received error entry for {ip} {pid}")

        if not self._dashboard_head.gcs_error_subscriber:
            # Redis fallback: decode PubSubMessage -> ErrorTableData.
            aioredis_client = self._dashboard_head.aioredis_client
            receiver = Receiver()

            key = gcs_utils.RAY_ERROR_PUBSUB_PATTERN
            await aioredis_client.psubscribe(receiver.pattern(key))
            logger.info("Subscribed to %s", key)

            async for _, msg in receiver.iter():
                try:
                    _, payload = msg
                    envelope = gcs_utils.PubSubMessage.FromString(payload)
                    record(gcs_utils.ErrorTableData.FromString(envelope.data))
                except Exception:
                    logger.exception("Error receiving error info from Redis.")
            return

        while True:
            try:
                _, error_data = \
                    await self._dashboard_head.gcs_error_subscriber.poll()
                if error_data is not None:
                    record(error_data)
            except Exception:
                logger.exception("Error receiving error info from GCS.")
Example #9
0
    async def _update_actors(self):
        """Keep ``DataSource.actors`` in sync with the GCS actor table.

        Subscribes to ``ACTOR_CHANNEL:*`` and calls ``GetAllActorInfo``
        (timeout=2) until it succeeds. It then resets ``DataSource.actors``
        keyed by hex actor id and applies each actor update from the channel.
        """
        client = self._dashboard_head.aioredis_client
        actor_receiver = Receiver()

        channel_key = "{}:*".format(stats_collector_consts.ACTOR_CHANNEL)
        await client.psubscribe(actor_receiver.pattern(channel_key))
        logger.info("Subscribed to %s", channel_key)

        # Load the initial actor snapshot, retrying after any failure.
        while True:
            try:
                logger.info("Getting all actor info from GCS.")
                reply = await self._gcs_actor_info_stub.GetAllActorInfo(
                    gcs_service_pb2.GetAllActorInfoRequest(), timeout=2)
                if reply.status.code != 0:
                    raise Exception(
                        f"Failed to GetAllActorInfo: {reply.status.message}")
                snapshot = {
                    binary_to_hex(info.actor_id): actor_table_data_to_dict(info)
                    for info in reply.actor_table_data
                }
                DataSource.actors.reset(snapshot)
                logger.info("Received %d actor info from GCS.",
                            len(snapshot))
                break
            except Exception:
                logger.exception("Error Getting all actor info from GCS.")
                await asyncio.sleep(stats_collector_consts.
                                    RETRY_GET_ALL_ACTOR_INFO_INTERVAL_SECONDS)

        # Apply incremental updates published on the actor channel.
        async for _sender, message in actor_receiver.iter():
            try:
                _channel, payload = message
                envelope = ray.gcs_utils.PubSubMessage.FromString(payload)
                info = ray.gcs_utils.ActorTableData.FromString(envelope.data)
                DataSource.actors[binary_to_hex(info.actor_id)] = \
                    actor_table_data_to_dict(info)
            except Exception:
                logger.exception("Error receiving actor info.")
Example #10
0
async def test_decode_message_for_pattern():
    """``Receiver.get`` applies ``encoding`` and/or ``decoder`` to the payload
    of a pattern message but leaves the matched channel name as bytes."""
    receiver = Receiver()
    star = receiver.pattern("*")

    cases = [
        # (payload put on the pattern, get() kwargs, expected decoded payload)
        (b"Some data", {"encoding": "utf-8"}, "Some data"),
        ('{"hello": "world"}', {"decoder": json.loads}, {"hello": "world"}),
        (b'{"hello": "world"}', {"encoding": "utf-8", "decoder": json.loads},
         {"hello": "world"}),
    ]
    for payload, kwargs, expected in cases:
        star.put_nowait((b"channel", payload))
        result = await receiver.get(**kwargs)
        assert isinstance(result[0], _Sender)
        assert result[1] == (b"channel", expected)
Example #11
0
async def test_decode_message_for_pattern(loop):
    """Pattern messages keep the raw channel name while ``get`` transforms
    the payload with ``encoding``, ``decoder``, or both."""
    mpsc = Receiver(loop)
    star = mpsc.pattern('*')

    async def roundtrip(payload, **get_kwargs):
        # Feed one (channel, payload) pair through the pattern, read it back,
        # check the sender type and hand back the (channel, payload) part.
        star.put_nowait((b'channel', payload))
        result = await mpsc.get(**get_kwargs)
        assert isinstance(result[0], _Sender)
        return result[1]

    assert await roundtrip(b'Some data', encoding='utf-8') == \
        (b'channel', 'Some data')
    assert await roundtrip('{"hello": "world"}', decoder=json.loads) == \
        (b'channel', {'hello': 'world'})
    assert await roundtrip(b'{"hello": "world"}', encoding='utf-8',
                           decoder=json.loads) == \
        (b'channel', {'hello': 'world'})
Example #12
0
async def test_decode_message_for_pattern(loop):
    """``Receiver.get`` on a pattern message returns ``(sender, (channel,
    payload))``, where only the payload is transformed by ``encoding`` and/or
    ``decoder``."""
    receiver = Receiver(loop)
    wildcard = receiver.pattern('*')
    parsed = {'hello': 'world'}

    # encoding only: bytes payload becomes str.
    wildcard.put_nowait((b'channel', b'Some data'))
    first = await receiver.get(encoding='utf-8')
    assert isinstance(first[0], _Sender)
    assert first[1] == (b'channel', 'Some data')

    # decoder only: str JSON payload is parsed.
    wildcard.put_nowait((b'channel', '{"hello": "world"}'))
    second = await receiver.get(decoder=json.loads)
    assert isinstance(second[0], _Sender)
    assert second[1] == (b'channel', parsed)

    # encoding and decoder together on a bytes JSON payload.
    wildcard.put_nowait((b'channel', b'{"hello": "world"}'))
    third = await receiver.get(encoding='utf-8', decoder=json.loads)
    assert isinstance(third[0], _Sender)
    assert third[1] == (b'channel', parsed)
Example #13
0
    async def run(self, server):
        """Store reporter-agent stats keyed by the node id in the channel name.

        Channels look like ``b'RAY_REPORTER:<node id hex>'``. The text after
        the last ``:`` is used as the key into
        ``DataSource.node_physical_stats``. ``server`` is not used here.
        """
        client = self._dashboard_head.aioredis_client
        stats_receiver = Receiver()

        pattern_key = "{}*".format(reporter_consts.REPORTER_PREFIX)
        await client.psubscribe(stats_receiver.pattern(pattern_key))
        logger.info(f"Subscribed to {pattern_key}")

        async for _sender, message in stats_receiver.iter():
            try:
                # e.g. b'RAY_REPORTER:2b4fbd406898cc86fb88fb0acfd5456b0afd87cf'
                channel, raw = message
                stats = json.loads(ray._private.utils.decode(raw))
                node_id = channel.decode("utf-8").split(":")[-1]
                DataSource.node_physical_stats[node_id] = stats
            except Exception:
                logger.exception(
                    "Error receiving node physical stats from reporter agent.")
Example #14
0
    async def run(self, server):
        """Start service discovery, then consume reporter-agent stats forever.

        ``self.service_discovery`` is marked daemon and started first. Stats
        come from ``GcsAioResourceUsageSubscriber`` when
        ``gcs_pubsub_enabled()``; otherwise they come from a Redis pattern
        subscription on ``REPORTER_PREFIX*``. In both cases each JSON payload
        is stored in ``DataSource.node_physical_stats`` under the text after
        the last ``:`` of its key. ``server`` is not used here.
        """
        # Need daemon True to avoid dashboard hangs at exit.
        self.service_discovery.daemon = True
        self.service_discovery.start()

        if not gcs_pubsub_enabled():
            from aioredis.pubsub import Receiver

            stats_receiver = Receiver()
            client = self._dashboard_head.aioredis_client
            pattern_key = "{}*".format(reporter_consts.REPORTER_PREFIX)
            await client.psubscribe(stats_receiver.pattern(pattern_key))
            logger.info(f"Subscribed to {pattern_key}")

            async for _sender, message in stats_receiver.iter():
                try:
                    # Redis delivers the channel name as bytes.
                    channel, raw = message
                    stats = json.loads(ray._private.utils.decode(raw))
                    node_id = channel.decode("utf-8").split(":")[-1]
                    DataSource.node_physical_stats[node_id] = stats
                except Exception:
                    logger.exception("Error receiving node physical stats "
                                     "from reporter agent.")
            return

        gcs_addr = self._dashboard_head.gcs_address
        subscriber = GcsAioResourceUsageSubscriber(gcs_addr)
        await subscriber.subscribe()

        while True:
            try:
                # The key is used with str.split(":"), so it is presumably a
                # str such as 'RAY_REPORTER:2b4fbd...'; a None key means
                # there was nothing to report.
                key, raw = await subscriber.poll()
                if key is None:
                    continue
                stats = json.loads(raw)
                node_id = key.split(":")[-1]
                DataSource.node_physical_stats[node_id] = stats
            except Exception:
                logger.exception("Error receiving node physical stats "
                                 "from reporter agent.")
Example #15
0
async def test_pubsub_receiver_stop_on_disconnect(create_redis, server, loop):
    """Killing the subscriber connection ends ``Receiver.iter()``."""
    pub = await create_redis(server.tcp_address, loop=loop)
    sub = await create_redis(server.tcp_address, loop=loop)
    sub_name = 'sub-{:X}'.format(id(sub))
    await sub.client_setname(sub_name)
    # Locate the subscriber's CLIENT LIST entry so it can be killed later.
    for sub_info in await pub.client_list():
        if sub_info.name == sub_name:
            break
    assert sub_info.name == sub_name

    receiver = Receiver(loop=loop)
    await sub.subscribe(receiver.channel('channel:1'))
    await sub.subscribe(receiver.channel('channel:2'))
    await sub.psubscribe(receiver.pattern('channel:*'))

    received = asyncio.Queue(loop=loop)
    done = object()  # sentinel queued once iteration has finished

    async def forward():
        async for channel, payload in receiver.iter(encoding='utf-8'):
            await received.put((channel.name, payload))
        await received.put(done)

    reader_task = asyncio.ensure_future(forward(), loop=loop)
    for name in ('channel:1', 'channel:2'):
        await pub.publish_json(name, ['hello'])

    # Each publish arrives once via its channel and once via the pattern.
    for name in (b'channel:1', b'channel:2'):
        assert await received.get() == (name, '["hello"]')
        assert await received.get() == (b'channel:*', (name, '["hello"]'))

    # XXX: need to implement `client kill`
    assert await pub.execute('client', 'kill', sub_info.addr) in (b'OK', 1)
    await asyncio.wait_for(reader_task, timeout=1, loop=loop)
    assert await received.get() is done
Example #16
0
async def test_pubsub_receiver_stop_on_disconnect(create_redis, server):
    """After ``CLIENT KILL`` on the subscriber, the reader loop over
    ``Receiver.iter()`` exits and signals completion."""
    pub = await create_redis(server.tcp_address)
    sub = await create_redis(server.tcp_address)
    sub_name = "sub-{:X}".format(id(sub))
    await sub.client_setname(sub_name)
    for sub_info in await pub.client_list():
        if sub_info.name == sub_name:
            break
    assert sub_info.name == sub_name

    mpsc = Receiver()
    await sub.subscribe(mpsc.channel("channel:1"))
    await sub.subscribe(mpsc.channel("channel:2"))
    await sub.psubscribe(mpsc.pattern("channel:*"))

    inbox = asyncio.Queue()
    finished = object()

    async def pump():
        # Relay (channel name, decoded message) pairs until iter() stops.
        async for ch, msg in mpsc.iter(encoding="utf-8"):
            await inbox.put((ch.name, msg))
        await inbox.put(finished)

    pump_task = asyncio.ensure_future(pump())
    await pub.publish_json("channel:1", ["hello"])
    await pub.publish_json("channel:2", ["hello"])

    expected = [
        (b"channel:1", '["hello"]'),
        (b"channel:*", (b"channel:1", '["hello"]')),
        (b"channel:2", '["hello"]'),
        (b"channel:*", (b"channel:2", '["hello"]')),
    ]
    for item in expected:
        assert await inbox.get() == item

    # XXX: need to implement `client kill`
    assert await pub.execute("client", "kill", sub_info.addr) in (b"OK", 1)
    await asyncio.wait_for(pump_task, timeout=1)
    assert await inbox.get() is finished
Example #17
0
async def test_pubsub_receiver_stop_on_disconnect(create_redis, server, loop):
    """Disconnecting the subscriber terminates ``Receiver.iter()``.

    Four messages (two channels plus one overlapping pattern) are drained
    first; then the subscriber is killed and the reader must finish
    within one second.
    """
    publisher = await create_redis(server.tcp_address, loop=loop)
    subscriber = await create_redis(server.tcp_address, loop=loop)
    sub_name = 'sub-{:X}'.format(id(subscriber))
    await subscriber.client_setname(sub_name)
    for sub_info in await publisher.client_list():
        if sub_info.name == sub_name:
            break
    assert sub_info.name == sub_name

    mpsc = Receiver(loop=loop)
    await subscriber.subscribe(mpsc.channel('channel:1'))
    await subscriber.subscribe(mpsc.channel('channel:2'))
    await subscriber.psubscribe(mpsc.pattern('channel:*'))

    queue = asyncio.Queue(loop=loop)
    end_marker = object()

    async def reader():
        async for ch, msg in mpsc.iter(encoding='utf-8'):
            await queue.put((ch.name, msg))
        # Only reached once the receiver stops yielding.
        await queue.put(end_marker)

    task = asyncio.ensure_future(reader(), loop=loop)
    await publisher.publish_json('channel:1', ['hello'])
    await publisher.publish_json('channel:2', ['hello'])

    got = [await queue.get() for _ in range(4)]
    assert got[0] == (b'channel:1', '["hello"]')
    assert got[1] == (b'channel:*', (b'channel:1', '["hello"]'))
    assert got[2] == (b'channel:2', '["hello"]')
    assert got[3] == (b'channel:*', (b'channel:2', '["hello"]'))

    # XXX: need to implement `client kill`
    assert await publisher.execute(
        'client', 'kill', sub_info.addr) in (b'OK', 1)
    await asyncio.wait_for(task, timeout=1, loop=loop)
    assert await queue.get() is end_marker
Example #18
0
class WebsocketServer:
    """Websocket proxy for public redis channels and hashes.

    Besides passing data through, the server supports some geo operations;
    see WebsocketHandler and protocol.CommandsMixin for details.
    """

    handler_class = WebsocketHandler

    def __init__(self,
                 redis,
                 subscriber,
                 read_timeout=None,
                 keep_alive_timeout=None):
        """Store the defaults used for every new WebsocketHandler.

        :param redis: aioredis.StrictRedis instance
        :param subscriber: aioredis.StrictRedis instance used for the
           (p)subscribe calls
        :param read_timeout: Timeout after which the websocket connection is
           checked and kept if still open (does not cancel an open connection)
        :param keep_alive_timeout: Time after which the server cancels the
           handler task (independently of its internal state)
        """
        self.redis = redis
        self.subscriber = subscriber
        self.read_timeout = read_timeout
        self.keep_alive_timeout = keep_alive_timeout
        # One Receiver is shared by every subscription; handlers are keyed
        # by the client's remote address.
        self.receiver = Receiver()
        self.handlers = {}

    async def websocket_handler(self, websocket, path):
        """Serve one websocket connection until it ends or times out.

        ``path`` is not used. The handler is registered for the duration of
        the connection and always unregistered and closed afterwards.
        """
        logger.info("Client %s connected", websocket.remote_address)
        receiver = self.receiver
        channel_names = set(map(bytes.decode, receiver.channels.keys()))
        pattern_names = set(map(bytes.decode, receiver.patterns.keys()))
        handler = await self.handler_class.create(
            self.redis,
            websocket,
            channel_names,
            pattern_names,
            read_timeout=self.read_timeout,
        )
        self.handlers[websocket.remote_address] = handler
        try:
            await asyncio.wait_for(handler.listen(), self.keep_alive_timeout)
        finally:
            del self.handlers[websocket.remote_address]
            await handler.close()
            logger.info("Client %s removed", websocket.remote_address)

    async def redis_subscribe(self, channel_names=None, channel_patterns=None):
        """Subscribe ``self.subscriber`` to the given channels and/or patterns.

        :raises ValueError: if both arguments are empty or None
        """
        if not channel_names and not channel_patterns:
            raise ValueError("Got nothing to subscribe to")

        if channel_names:
            channels = [self.receiver.channel(name) for name in channel_names]
            await self.subscriber.subscribe(*channels)

        if channel_patterns:
            patterns = [self.receiver.pattern(p) for p in channel_patterns]
            await self.subscriber.psubscribe(*patterns)

    async def redis_reader(self):
        """Fan out every message received on ``self.receiver`` to handlers.

        Pattern messages are routed by the concrete channel they were
        published on, plain messages by their channel name. A new Message is
        queued for every handler whose ``subscriptions`` contain that name.
        """
        async for channel, msg in self.receiver.iter(encoding="utf-8"):
            if not channel.is_pattern:
                source = channel.name.decode()
            else:
                # Pattern payloads are (published channel, message) pairs.
                source, msg = msg[0].decode(), msg[1]

            for handler in self.handlers.values():
                if source in handler.subscriptions:
                    handler.queue.put_nowait(
                        Message(source=source, content=msg))

    def listen(self,
               host,
               port,
               channel_names=None,
               channel_patterns=None,
               loop=None):
        """Subscribe to redis, start the websocket server, then relay forever.

        Blocks by driving ``loop`` (``asyncio.get_event_loop()`` when not
        given) through subscription, server start-up and ``redis_reader``.
        """
        loop = loop if loop else asyncio.get_event_loop()
        server_startup = serve(self.websocket_handler, host, port)
        loop.run_until_complete(
            self.redis_subscribe(channel_names, channel_patterns))
        loop.run_until_complete(server_startup)
        logger.info("Listening on %s:%s...", host, port)
        loop.run_until_complete(self.redis_reader())
Example #19
0
    async def _update_actors(self):
        """Keep the dashboard's actor indexes in sync with GCS.

        Subscribes to ``ACTOR_CHANNEL:*``, then retries ``GetAllActorInfo``
        (timeout=5) until it succeeds and rebuilds ``DataSource.actors``,
        ``DataSource.job_actors`` (by ``jobId``) and ``DataSource.node_actors``
        (by ``address.rayletId``, skipping ``NIL_NODE_ID``). Afterwards it
        applies every actor update published on the channel; failures are
        logged and the loop continues.
        """
        # TODO(fyrestone): Refactor code for updating actor / node / job.
        # Subscribe actor channel.
        aioredis_client = self._dashboard_head.aioredis_client
        receiver = Receiver()

        key = "{}:*".format(actor_consts.ACTOR_CHANNEL)
        pattern = receiver.pattern(key)
        await aioredis_client.psubscribe(pattern)
        logger.info("Subscribed to %s", key)

        def _process_actor_table_data(data):
            # Annotate the record in place with a derived "actorClass".
            actor_class = actor_classname_from_task_spec(
                data.get("taskSpec", {}))
            data["actorClass"] = actor_class

        # Get all actor info.
        while True:
            try:
                logger.info("Getting all actor info from GCS.")
                request = gcs_service_pb2.GetAllActorInfoRequest()
                reply = await self._gcs_actor_info_stub.GetAllActorInfo(
                    request, timeout=5)
                if reply.status.code == 0:
                    actors = {}
                    for message in reply.actor_table_data:
                        actor_table_data = actor_table_data_to_dict(message)
                        _process_actor_table_data(actor_table_data)
                        actors[actor_table_data["actorId"]] = actor_table_data
                    # Update actors.
                    DataSource.actors.reset(actors)
                    # Update node actors and job actors.
                    job_actors = {}
                    node_actors = {}
                    for actor_id, actor_table_data in actors.items():
                        job_id = actor_table_data["jobId"]
                        node_id = actor_table_data["address"]["rayletId"]
                        job_actors.setdefault(job_id,
                                              {})[actor_id] = actor_table_data
                        # Update only when node_id is not Nil.
                        if node_id != actor_consts.NIL_NODE_ID:
                            node_actors.setdefault(
                                node_id, {})[actor_id] = actor_table_data
                    DataSource.job_actors.reset(job_actors)
                    DataSource.node_actors.reset(node_actors)
                    logger.info("Received %d actor info from GCS.",
                                len(actors))
                    break
                else:
                    raise Exception(
                        f"Failed to GetAllActorInfo: {reply.status.message}")
            except Exception:
                logger.exception("Error Getting all actor info from GCS.")
                await asyncio.sleep(
                    actor_consts.RETRY_GET_ALL_ACTOR_INFO_INTERVAL_SECONDS)

        # Receive actors from channel.
        # Fields copied from an update onto the already-stored actor record.
        state_keys = ("state", "address", "numRestarts", "timestamp", "pid")
        async for sender, msg in receiver.iter():
            try:
                # Pattern messages are (channel name bytes, payload).
                actor_id, actor_table_data = msg
                pubsub_message = ray.gcs_utils.PubSubMessage.FromString(
                    actor_table_data)
                message = ray.gcs_utils.ActorTableData.FromString(
                    pubsub_message.data)
                actor_table_data = actor_table_data_to_dict(message)
                _process_actor_table_data(actor_table_data)
                # If actor is not new registered but updated, we only update
                # states related fields.
                if actor_table_data["state"] != "DEPENDENCIES_UNREADY":
                    # Strip "<TablePrefix_ACTOR_string>:" from the channel
                    # name to get the stored actor id. NOTE(review): assumes
                    # the channel always starts with that prefix.
                    actor_id = actor_id.decode(
                        "UTF-8")[len(ray.gcs_utils.TablePrefix_ACTOR_string +
                                     ":"):]
                    # NOTE(review): raises KeyError (caught and logged below,
                    # dropping the update) if the actor is not yet stored.
                    actor_table_data_copy = dict(DataSource.actors[actor_id])
                    for k in state_keys:
                        actor_table_data_copy[k] = actor_table_data[k]
                    actor_table_data = actor_table_data_copy
                # From here on the id comes from the record, not the channel.
                actor_id = actor_table_data["actorId"]
                job_id = actor_table_data["jobId"]
                node_id = actor_table_data["address"]["rayletId"]
                # Update actors.
                DataSource.actors[actor_id] = actor_table_data
                # Update node actors (only when node_id is not Nil).
                # The per-node / per-job dicts are copied and reassigned
                # rather than mutated in place.
                if node_id != actor_consts.NIL_NODE_ID:
                    node_actors = dict(DataSource.node_actors.get(node_id, {}))
                    node_actors[actor_id] = actor_table_data
                    DataSource.node_actors[node_id] = node_actors
                # Update job actors.
                job_actors = dict(DataSource.job_actors.get(job_id, {}))
                job_actors[actor_id] = actor_table_data
                DataSource.job_actors[job_id] = job_actors
            except Exception:
                logger.exception("Error receiving actor info.")
Example #20
0
    async def _update_actors(self):
        """Mirror GCS actor state into the ``DataSource`` actor indexes.

        Subscribes to ``ACTOR_CHANNEL:*`` and loads the full actor table with
        ``GetAllActorInfo`` (timeout=5), retrying until it succeeds. It then
        rebuilds ``DataSource.actors``, ``job_actors`` (by ``jobId``) and
        ``node_actors`` (by ``address.rayletId``), and applies every actor
        update published on the channel to all three indexes.
        """
        client = self._dashboard_head.aioredis_client
        actor_receiver = Receiver()

        channel_key = "{}:*".format(stats_collector_consts.ACTOR_CHANNEL)
        await client.psubscribe(actor_receiver.pattern(channel_key))
        logger.info("Subscribed to %s", channel_key)

        def add_actor_class(record):
            # Annotate the record with the class name derived from its spec.
            record["actorClass"] = actor_classname_from_task_spec(
                record.get("taskSpec", {}))

        # Load the initial snapshot, retrying after any failure.
        while True:
            try:
                logger.info("Getting all actor info from GCS.")
                reply = await self._gcs_actor_info_stub.GetAllActorInfo(
                    gcs_service_pb2.GetAllActorInfoRequest(), timeout=5)
                if reply.status.code != 0:
                    raise Exception(
                        f"Failed to GetAllActorInfo: {reply.status.message}")
                by_id = {}
                for entry in reply.actor_table_data:
                    record = actor_table_data_to_dict(entry)
                    add_actor_class(record)
                    by_id[record["actorId"]] = record
                DataSource.actors.reset(by_id)
                by_job = {}
                by_node = {}
                for actor_id, record in by_id.items():
                    job_id = record["jobId"]
                    node_id = record["address"]["rayletId"]
                    by_job.setdefault(job_id, {})[actor_id] = record
                    by_node.setdefault(node_id, {})[actor_id] = record
                DataSource.job_actors.reset(by_job)
                DataSource.node_actors.reset(by_node)
                logger.info("Received %d actor info from GCS.", len(by_id))
                break
            except Exception:
                logger.exception("Error Getting all actor info from GCS.")
                await asyncio.sleep(stats_collector_consts.
                                    RETRY_GET_ALL_ACTOR_INFO_INTERVAL_SECONDS)

        # Apply incremental updates published on the actor channel.
        async for _sender, message in actor_receiver.iter():
            try:
                _channel, payload = message
                envelope = ray.gcs_utils.PubSubMessage.FromString(payload)
                record = actor_table_data_to_dict(
                    ray.gcs_utils.ActorTableData.FromString(envelope.data))
                add_actor_class(record)
                actor_id = record["actorId"]
                job_id = record["jobId"]
                node_id = record["address"]["rayletId"]
                DataSource.actors[actor_id] = record
                # Replace (rather than mutate) the per-node and per-job dicts.
                node_entry = dict(DataSource.node_actors.get(node_id, {}))
                node_entry[actor_id] = record
                DataSource.node_actors[node_id] = node_entry
                job_entry = dict(DataSource.job_actors.get(job_id, {}))
                job_entry[actor_id] = record
                DataSource.job_actors[job_id] = job_entry
            except Exception:
                logger.exception("Error receiving actor info.")
from aioredis.pubsub import Receiver
from aioredis.abc import AbcChannel


async def main(redis, loop=None):
    """Fan several channels and one pattern into a single ``Receiver``.

    :param redis: aioredis connection used for the (p)subscribe calls.
    :param loop: event loop forwarded to ``Receiver`` (may be ``None``).

    Run it as a coroutine, e.g. ``loop.run_until_complete(main(redis,
    loop))``; ``await`` is only valid inside an ``async def``.
    """
    import asyncio

    mpsc = Receiver(loop=loop)

    async def reader(mpsc):
        async for channel, msg in mpsc.iter():
            assert isinstance(channel, AbcChannel)
            print("Got {!r} in channel {!r}".format(msg, channel))

    # Keep a reference to the task so it is not garbage-collected early and
    # can be awaited once the receiver is stopped.
    reader_task = asyncio.ensure_future(reader(mpsc))
    await redis.subscribe(mpsc.channel('channel:1'),
                          mpsc.channel('channel:3'),
                          mpsc.channel('channel:5'))
    # Redis patterns are globs: 'hello' alone matches only the channel named
    # exactly 'hello', so 'hello*' is needed to catch 'hello-channel'.
    await redis.psubscribe(mpsc.pattern('hello*'))
    # publishing 'Hello world' into 'hello-channel'
    # will print something like (pattern messages are (channel, data)):
    #   Got (b'hello-channel', b'Hello world') in channel <...>

    # when all is done:
    await redis.unsubscribe('channel:1', 'channel:3', 'channel:5')
    await redis.punsubscribe('hello*')
    mpsc.stop()
    # any message received after stop() will be ignored; iter() ends, so the
    # reader task finishes.
    await reader_task