Example #1
    async def _update_error_info(self):
        aioredis_client = self._dashboard_head.aioredis_client
        receiver = Receiver()

        key = ray.gcs_utils.RAY_ERROR_PUBSUB_PATTERN
        pattern = receiver.pattern(key)
        await aioredis_client.psubscribe(pattern)
        logger.info("Subscribed to %s", key)

        async for sender, msg in receiver.iter():
            try:
                _, data = msg
                pubsub_msg = ray.gcs_utils.PubSubMessage.FromString(data)
                error_data = ray.gcs_utils.ErrorTableData.FromString(
                    pubsub_msg.data)
                message = error_data.error_message
                message = re.sub(r"\x1b\[\d+m", "", message)
                match = re.search(r"\(pid=(\d+), ip=(.*?)\)", message)
                if match:
                    pid = match.group(1)
                    ip = match.group(2)
                    errs_for_ip = DataSource.ip_and_pid_to_errors.get(ip, {})
                    pid_errors = errs_for_ip.get(pid, [])
                    pid_errors.append({
                        "message": message,
                        "timestamp": error_data.timestamp,
                        "type": error_data.type
                    })
                    errs_for_ip[pid] = pid_errors
                    DataSource.ip_and_pid_to_errors[ip] = errs_for_ip
                    logger.info(f"Received error entry for {ip} {pid}")
            except Exception:
                logger.exception("Error receiving error info.")
Example #2
    async def _update_log_info(self):
        aioredis_client = self._dashboard_head.aioredis_client
        receiver = Receiver()

        channel = receiver.channel(gcs_utils.LOG_FILE_CHANNEL)
        await aioredis_client.subscribe(channel)
        logger.info("Subscribed to %s", channel)

        async for sender, msg in receiver.iter():
            try:
                data = json.loads(ray._private.utils.decode(msg))
                ip = data["ip"]
                pid = str(data["pid"])
                if pid != "autoscaler":
                    logs_for_ip = dict(
                        DataSource.ip_and_pid_to_logs.get(ip, {}))
                    logs_for_pid = list(logs_for_ip.get(pid, []))
                    logs_for_pid.extend(data["lines"])

                    # Only cache up to MAX_LOGS_TO_CACHE
                    logs_length = len(logs_for_pid)
                    if logs_length > MAX_LOGS_TO_CACHE * LOG_PRUNE_THREASHOLD:
                        offset = logs_length - MAX_LOGS_TO_CACHE
                        del logs_for_pid[:offset]

                    logs_for_ip[pid] = logs_for_pid
                    DataSource.ip_and_pid_to_logs[ip] = logs_for_ip
                logger.info(f"Received a log for {ip} and {pid}")
            except Exception:
                logger.exception("Error receiving log info.")
Example #3
    async def run(self, loop, **kwargs):
        aredis = await create_aredis()
        mpsc = Receiver()
        pattern = mpsc.pattern("TICK:*")
        await aredis.psubscribe(pattern)
        logger.info("subscribed")
        async for channel, data in mpsc.iter():
            data = ujson.loads(data[1])
            self.handle_tick(data)
Example #4
    async def _update_jobs(self):
        # Subscribe job channel.
        aioredis_client = self._dashboard_head.aioredis_client
        receiver = Receiver()

        key = f"{job_consts.JOB_CHANNEL}:*"
        pattern = receiver.pattern(key)
        await aioredis_client.psubscribe(pattern)
        logger.info("Subscribed to %s", key)

        # Get all job info.
        while True:
            try:
                logger.info("Getting all job info from GCS.")
                request = gcs_service_pb2.GetAllJobInfoRequest()
                reply = await self._gcs_job_info_stub.GetAllJobInfo(request,
                                                                    timeout=5)
                if reply.status.code == 0:
                    jobs = {}
                    for job_table_data in reply.job_info_list:
                        data = job_table_data_to_dict(job_table_data)
                        jobs[data["jobId"]] = data
                    # Update jobs.
                    DataSource.jobs.reset(jobs)
                    logger.info("Received %d job info from GCS.", len(jobs))
                    break
                else:
                    raise Exception(
                        f"Failed to GetAllJobInfo: {reply.status.message}")
            except Exception:
                logger.exception("Error Getting all job info from GCS.")
                await asyncio.sleep(
                    job_consts.RETRY_GET_ALL_JOB_INFO_INTERVAL_SECONDS)

        # Receive jobs from channel.
        async for sender, msg in receiver.iter():
            try:
                _, data = msg
                pubsub_message = ray.gcs_utils.PubSubMessage.FromString(data)
                message = ray.gcs_utils.JobTableData.FromString(
                    pubsub_message.data)
                job_id = ray._raylet.JobID(message.job_id)
                if job_id.is_submitted_from_dashboard():
                    job_table_data = job_table_data_to_dict(message)
                    job_id = job_table_data["jobId"]
                    # Update jobs.
                    DataSource.jobs[job_id] = job_table_data
                else:
                    logger.info(
                        "Ignore job %s which is not submitted from dashboard.",
                        job_id.hex())
            except Exception:
                logger.exception("Error receiving job info.")
Example #5
    async def _event_listener(self, channel):

        async for _ in AsyncCirculator():

            async with self._redis_pool.get_client() as cache:

                receiver = Receiver()

                await cache.subscribe(receiver.channel(channel))

                async for channel, message in receiver.iter():
                    await self._event_assigner(channel, message)
Example #6
    async def _update_error_info(self):
        def process_error(error_data):
            message = error_data.error_message
            message = re.sub(r"\x1b\[\d+m", "", message)
            match = re.search(r"\(pid=(\d+), ip=(.*?)\)", message)
            if match:
                pid = match.group(1)
                ip = match.group(2)
                errs_for_ip = dict(DataSource.ip_and_pid_to_errors.get(ip, {}))
                pid_errors = list(errs_for_ip.get(pid, []))
                pid_errors.append(
                    {
                        "message": message,
                        "timestamp": error_data.timestamp,
                        "type": error_data.type,
                    }
                )
                errs_for_ip[pid] = pid_errors
                DataSource.ip_and_pid_to_errors[ip] = errs_for_ip
                logger.info(f"Received error entry for {ip} {pid}")

        if self._dashboard_head.gcs_error_subscriber:
            while True:
                try:
                    (
                        _,
                        error_data,
                    ) = await self._dashboard_head.gcs_error_subscriber.poll()
                    if error_data is None:
                        continue
                    process_error(error_data)
                except Exception:
                    logger.exception("Error receiving error info from GCS.")
        else:
            aioredis_client = self._dashboard_head.aioredis_client
            receiver = Receiver()

            key = gcs_utils.RAY_ERROR_PUBSUB_PATTERN
            pattern = receiver.pattern(key)
            await aioredis_client.psubscribe(pattern)
            logger.info("Subscribed to %s", key)

            async for _, msg in receiver.iter():
                try:
                    _, data = msg
                    pubsub_msg = gcs_utils.PubSubMessage.FromString(data)
                    error_data = gcs_utils.ErrorTableData.FromString(pubsub_msg.data)
                    process_error(error_data)
                except Exception:
                    logger.exception("Error receiving error info from Redis.")
Example #7
    async def run(self):
        p = self._dashboard_head.aioredis_client
        mpsc = Receiver()

        reporter_key = "{}*".format(reporter_consts.REPORTER_PREFIX)
        await p.psubscribe(mpsc.pattern(reporter_key))
        logger.info("Subscribed to {}".format(reporter_key))

        async for sender, msg in mpsc.iter():
            try:
                _, data = msg
                data = json.loads(ray.utils.decode(data))
                DataSource.node_physical_stats[data["ip"]] = data
            except Exception as ex:
                logger.exception(ex)
Example #8
    async def run(self, server):
        aioredis_client = self._dashboard_head.aioredis_client
        receiver = Receiver()

        reporter_key = "{}*".format(reporter_consts.REPORTER_PREFIX)
        await aioredis_client.psubscribe(receiver.pattern(reporter_key))
        logger.info(f"Subscribed to {reporter_key}")

        async for sender, msg in receiver.iter():
            try:
                _, data = msg
                data = json.loads(ray.utils.decode(data))
                DataSource.node_physical_stats[data["ip"]] = data
            except Exception:
                logger.exception(
                    "Error receiving node physical stats from reporter agent.")
Example #9
async def test_pubsub_receiver_call_stop_with_empty_queue(
        create_redis, server, loop):
    sub = await create_redis(server.tcp_address, loop=loop)

    mpsc = Receiver(loop=loop)

    # FIXME: currently at least one subscriber is needed
    snd1, = await sub.subscribe(mpsc.channel('chan:1'))

    now = loop.time()
    loop.call_later(.5, mpsc.stop)
    async for i in mpsc.iter():  # noqa (flake8 bug with async for)
        assert False, "StopAsyncIteration not raised"
    dt = loop.time() - now
    assert dt <= 1.5
    assert not mpsc.is_active
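The test above relies on Receiver.stop() to terminate iter(). A small standalone sketch of that shutdown pattern, assuming a Redis server at redis://localhost and an arbitrary channel name chan:1:

import asyncio

import aioredis
from aioredis.pubsub import Receiver


async def consume_briefly(redis_url="redis://localhost"):
    redis = await aioredis.create_redis(redis_url)
    mpsc = Receiver()
    await redis.subscribe(mpsc.channel("chan:1"))

    # Schedule stop(); iter() then finishes and the loop below exits.
    asyncio.get_event_loop().call_later(0.5, mpsc.stop)
    async for channel, msg in mpsc.iter():
        print(channel.name, msg)

    assert not mpsc.is_active
    redis.close()
    await redis.wait_closed()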
Example #10
    async def listen_redis(self, websocket, channels):
        channel = channels.pop(0)
        try:
            mpsc = Receiver(loop=asyncio.get_event_loop())
            await self.sub.subscribe(
                mpsc.channel(channel),
                *(mpsc.channel(channel) for channel in channels))
            async for channel, msg in mpsc.iter():
                if websocket.client_state == WebSocketState.CONNECTED:
                    await websocket.send_bytes(msg)
        except:
            import traceback
            traceback.print_exc()
            await self.close_connections()
        finally:
            logger.info('Connection closed')
Example #12
async def subscribe_handler(_request, ws):
    data = await ws.recv()
    try:
        decoded_data = ujson.decode(data)
    except ValueError:
        await ws.send(ujson.dumps({'error': 'invalid json data'}))
        return

    token = decoded_data.get('token', '')

    loop = asyncio.get_event_loop()
    redis = await storage.get_async_redis(loop)

    mpsc = Receiver(loop=loop)
    await redis.subscribe(mpsc.channel(f'updates:{token}'))
    async for channel, msg in mpsc.iter():
        await ws.send(msg.decode())
Example #13
    async def _update_actors(self):
        # Subscribe actor channel.
        aioredis_client = self._dashboard_head.aioredis_client
        receiver = Receiver()

        key = "{}:*".format(stats_collector_consts.ACTOR_CHANNEL)
        pattern = receiver.pattern(key)
        await aioredis_client.psubscribe(pattern)
        logger.info("Subscribed to %s", key)

        # Get all actor info.
        while True:
            try:
                logger.info("Getting all actor info from GCS.")
                request = gcs_service_pb2.GetAllActorInfoRequest()
                reply = await self._gcs_actor_info_stub.GetAllActorInfo(
                    request, timeout=2)
                if reply.status.code == 0:
                    result = {}
                    for actor_info in reply.actor_table_data:
                        result[binary_to_hex(actor_info.actor_id)] = \
                            actor_table_data_to_dict(actor_info)
                    DataSource.actors.reset(result)
                    logger.info("Received %d actor info from GCS.",
                                len(result))
                    break
                else:
                    raise Exception(
                        f"Failed to GetAllActorInfo: {reply.status.message}")
            except Exception:
                logger.exception("Error Getting all actor info from GCS.")
                await asyncio.sleep(stats_collector_consts.
                                    RETRY_GET_ALL_ACTOR_INFO_INTERVAL_SECONDS)

        # Receive actors from channel.
        async for sender, msg in receiver.iter():
            try:
                _, data = msg
                pubsub_message = ray.gcs_utils.PubSubMessage.FromString(data)
                actor_info = ray.gcs_utils.ActorTableData.FromString(
                    pubsub_message.data)
                DataSource.actors[binary_to_hex(actor_info.actor_id)] = \
                    actor_table_data_to_dict(actor_info)
            except Exception:
                logger.exception("Error receiving actor info.")
Example #14
    async def run(self, loop, **kwargs):
        redis = await create_aredis()
        self.redis = await create_aredis()
        recv = Receiver()
        await redis.psubscribe(recv.pattern("TICK:*"), recv.pattern("BAR:*"))
        async for _, (channel, data) in recv.iter():
            if isinstance(channel, bytes):
                channel = channel.decode()
            channel = channel.split(":")
            d = ujson.loads(data)
            if channel[0] == "BAR":
                freq, instrument = channel[1:]
                key = f"ZSET:BAR:{freq}:{instrument}"
                score = d['timestamp']
            else:
                instrument = channel[1]
                key = f"ZSET:TICK:{instrument}"
                score = d['Timestamp']
            await self.redis.zadd(key, score, data)
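Example #14 above funnels two patterns into a single Receiver and branches on the channel name. A trimmed, standalone sketch of that dispatch idea, with placeholder pattern names and an assumed local Redis server:

import aioredis
from aioredis.pubsub import Receiver


async def dispatch_ticks_and_bars(redis_url="redis://localhost"):
    redis = await aioredis.create_redis(redis_url)
    recv = Receiver()
    # One Receiver can multiplex several pattern subscriptions at once.
    await redis.psubscribe(recv.pattern("TICK:*"), recv.pattern("BAR:*"))

    async for _, (channel, data) in recv.iter():
        name = channel.decode() if isinstance(channel, bytes) else channel
        if name.startswith("BAR:"):
            print("bar message", name, data)
        else:
            print("tick message", name, data)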
Example #15
    async def run(self, server):
        aioredis_client = self._dashboard_head.aioredis_client
        receiver = Receiver()

        reporter_key = "{}*".format(reporter_consts.REPORTER_PREFIX)
        await aioredis_client.psubscribe(receiver.pattern(reporter_key))
        logger.info(f"Subscribed to {reporter_key}")

        async for sender, msg in receiver.iter():
            try:
                # The key is b'RAY_REPORTER:{node id hex}',
                # e.g. b'RAY_REPORTER:2b4fbd406898cc86fb88fb0acfd5456b0afd87cf'
                key, data = msg
                data = json.loads(ray._private.utils.decode(data))
                key = key.decode("utf-8")
                node_id = key.split(":")[-1]
                DataSource.node_physical_stats[node_id] = data
            except Exception:
                logger.exception(
                    "Error receiving node physical stats from reporter agent.")
Example #16
    async def run(self, server):
        # daemon=True is needed so the dashboard does not hang at exit.
        self.service_discovery.daemon = True
        self.service_discovery.start()
        if gcs_pubsub_enabled():
            gcs_addr = self._dashboard_head.gcs_address
            subscriber = GcsAioResourceUsageSubscriber(gcs_addr)
            await subscriber.subscribe()

            while True:
                try:
                    # The key is b'RAY_REPORTER:{node id hex}',
                    # e.g. b'RAY_REPORTER:2b4fbd...'
                    key, data = await subscriber.poll()
                    if key is None:
                        continue
                    data = json.loads(data)
                    node_id = key.split(":")[-1]
                    DataSource.node_physical_stats[node_id] = data
                except Exception:
                    logger.exception("Error receiving node physical stats "
                                     "from reporter agent.")
        else:
            from aioredis.pubsub import Receiver

            receiver = Receiver()
            aioredis_client = self._dashboard_head.aioredis_client
            reporter_key = "{}*".format(reporter_consts.REPORTER_PREFIX)
            await aioredis_client.psubscribe(receiver.pattern(reporter_key))
            logger.info(f"Subscribed to {reporter_key}")

            async for sender, msg in receiver.iter():
                try:
                    key, data = msg
                    data = json.loads(ray._private.utils.decode(data))
                    key = key.decode("utf-8")
                    node_id = key.split(":")[-1]
                    DataSource.node_physical_stats[node_id] = data
                except Exception:
                    logger.exception("Error receiving node physical stats "
                                     "from reporter agent.")
Example #17
    async def _update_log_info(self):
        aioredis_client = self._dashboard_head.aioredis_client
        receiver = Receiver()

        channel = receiver.channel(ray.gcs_utils.LOG_FILE_CHANNEL)
        await aioredis_client.subscribe(channel)
        logger.info("Subscribed to %s", channel)

        async for sender, msg in receiver.iter():
            try:
                data = json.loads(ray._private.utils.decode(msg))
                ip = data["ip"]
                pid = str(data["pid"])
                logs_for_ip = dict(DataSource.ip_and_pid_to_logs.get(ip, {}))
                logs_for_pid = list(logs_for_ip.get(pid, []))
                logs_for_pid.extend(data["lines"])
                logs_for_ip[pid] = logs_for_pid
                DataSource.ip_and_pid_to_logs[ip] = logs_for_ip
                logger.info(f"Received a log for {ip} and {pid}")
            except Exception:
                logger.exception("Error receiving log info.")
Example #18
class WebsocketServer:

    """Provide websocket proxy to public redis channels and hashes.

    In addition to passing through data this server supports some geo
    operations, see the WebsocketHandler and protocol.CommandsMixin for
    details.
    """

    handler_class = WebsocketHandler

    def __init__(self, redis, subscriber,
                 read_timeout=None, keep_alive_timeout=None):
        """Set default values for new WebsocketHandlers.

        :param redis: aioredis.StrictRedis instance
        :param subscriber: aioredis.StrictRedis instance
        :param read_timeout: Timeout, after which the websocket connection is
           checked and kept if still open (does not cancel an open connection)
        :param keep_alive_timeout: Time after which the server cancels the
           handler task (independently of its internal state)
        """

        self.read_timeout = read_timeout
        self.keep_alive_timeout = keep_alive_timeout
        self.receiver = Receiver()
        self.handlers = {}
        self.redis = redis
        self.subscriber = subscriber

    async def websocket_handler(self, websocket, path):
        """Return handler for a single websocket connection."""

        logger.info("Client %s connected", websocket.remote_address)
        handler = await self.handler_class.create(
            self.redis, websocket, set(map(
                bytes.decode, self.receiver.channels.keys())),
            read_timeout=self.read_timeout)
        self.handlers[websocket.remote_address] = handler
        try:
            await asyncio.wait_for(handler.listen(), self.keep_alive_timeout)
        finally:
            del self.handlers[websocket.remote_address]
            await handler.close()
            logger.info("Client %s removed", websocket.remote_address)

    async def redis_subscribe(self, channel_names):
        """Subscribe to all channels in channel_names once."""

        await self.subscriber.subscribe(*(self.receiver.channel(channel_name)
                                        for channel_name in channel_names))

    async def redis_reader(self):
        """Pass messages from subscribed channels to handlers."""

        async for channel, msg in self.receiver.iter(encoding='utf-8'):
            for handler in self.handlers.values():
                if channel.name.decode() in handler.subscriptions:
                    handler.queue.put_nowait(Message(
                        source=channel.name.decode(), content=msg))

    def listen(self, host, port, channel_names, loop=None):
        """Listen for websocket connections and manage redis subscriptions."""

        loop = loop or asyncio.get_event_loop()
        start_server = serve(self.websocket_handler, host, port)
        loop.run_until_complete(self.redis_subscribe(channel_names))
        loop.run_until_complete(start_server)
        logger.info("Listening on %s:%s...", host, port)
        loop.run_until_complete(self.redis_reader())
Example #19
class DashLink(BaseCog):
    def __init__(self, bot):
        super().__init__(bot)
        bot.loop.create_task(self.init())
        self.redis_link: aioredis.Redis = None
        self.receiver = Receiver(loop=bot.loop)
        self.handlers = dict(question=self.question,
                             update=self.update,
                             user_guilds=self.user_guilds,
                             user_guilds_end=self.user_guilds_end,
                             guild_info_watch=self.guild_info_watch,
                             guild_info_watch_end=self.guild_info_watch_end)
        self.question_handlers = dict(
            heartbeat=self.still_spinning,
            user_info=self.user_info_request,
            get_guild_settings=self.get_guild_settings,
            save_guild_settings=self.save_guild_settings,
            replace_guild_settings=self.replace_guild_settings,
            setup_mute=self.setup_mute,
            cleanup_mute=self.cleanup_mute,
            cache_info=self.cache_info,
            guild_user_perms=self.guild_user_perms)
        # The last time we received a heartbeat, the current attempt number, how many times we have notified the owner
        self.last_dash_heartbeat = [time.time(), 0, 0]
        self.last_update = datetime.now()
        self.to_log = dict()
        self.update_message = None

        if Configuration.get_master_var(
                "TRANSLATIONS",
                dict(SOURCE="SITE", CHANNEL=0, KEY="", LOGIN="",
                     WEBROOT=""))["SOURCE"] == 'CROWDIN':
            self.handlers["crowdin_webhook"] = self.crowdin_webhook
        self.task = self._receiver()

    def cog_unload(self):
        self.bot.loop.create_task(self._unload())

    async def _unload(self):
        for c in self.receiver.channels.values():
            self.redis_link.unsubscribe(c)
        self.receiver.stop()
        self.redis_link.close()
        await self.redis_link.wait_closed()

    async def init(self):
        try:
            self.redis_link = await aioredis.create_redis_pool(
                (Configuration.get_master_var('REDIS_HOST', "localhost"),
                 Configuration.get_master_var('REDIS_PORT', 6379)),
                encoding="utf-8",
                db=0,
                maxsize=2)  # size 2: one send, one receive
            self.bot.loop.create_task(self._receiver())

            if Configuration.get_master_var("DASH_OUTAGE")["outage_detection"]:
                self.bot.loop.create_task(self.dash_monitor())

            await self.redis_link.subscribe(
                self.receiver.channel("dash-bot-messages"))
            await self.redis_link.publish_json(
                "bot-dash-messages", {
                    'type': 'cache_info',
                    'message': await self.cache_info()
                })

        except OSError:
            await GearbotLogging.bot_log("Failed to connect to the dash!")

    async def dash_monitor(self):
        DASH_OUTAGE_INFO: dict = Configuration.get_master_var("DASH_OUTAGE")
        DASH_OUTAGE_CHANNEl = DASH_OUTAGE_INFO["dash_outage_channel"]
        MAX_WARNINGS = DASH_OUTAGE_INFO["max_bot_outage_warnings"]
        BOT_OUTAGE_PINGED_ROLES = DASH_OUTAGE_INFO["dash_outage_pinged_roles"]

        while True:
            if (time.time() - self.last_dash_heartbeat[0]) > 5:
                self.last_dash_heartbeat[1] += 1

                if self.last_dash_heartbeat[
                        1] >= 3 and self.last_dash_heartbeat[2] < MAX_WARNINGS:
                    print(
                        "The dashboard API keepalive hasn't responded in over 3 minutes!"
                    )

                    self.last_dash_heartbeat[2] += 1
                    self.last_dash_heartbeat[1] = 0

                    if DASH_OUTAGE_CHANNEl:
                        outage_message = DASH_OUTAGE_INFO["dash_outage_embed"]

                        # Apply the timestamp
                        outage_message["timestamp"] = datetime.now().isoformat(
                        )

                        # Set the color to the format Discord understands
                        outage_message["color"] = outage_message["color"]

                        # Generate the custom message and role pings
                        notify_message = DASH_OUTAGE_INFO[
                            "dash_outage_message"]
                        if BOT_OUTAGE_PINGED_ROLES:
                            pinged_roles = []
                            for role_id in BOT_OUTAGE_PINGED_ROLES:
                                pinged_roles.append(f"<@&{role_id}>")

                            notify_message += f" Pinging: {', '.join(pinged_roles)}"

                        try:
                            outage_channel = self.bot.get_channel(
                                DASH_OUTAGE_CHANNEl)
                            await outage_channel.send(
                                notify_message,
                                embed=Embed.from_dict(outage_message))
                        except Forbidden:
                            GearbotLogging.error(
                                "We couldn't access the specified channel, the notification will not be sent!"
                            )

            # Wait a little bit longer so the dashboard has a chance to update before we check
            await asyncio.sleep(65)

    async def _handle(self, sender, message):
        try:
            await self.handlers[message["type"]](message["message"])
        except CancelledError:
            return
        except Exception as e:
            await TheRealGearBot.handle_exception("Dash message handling",
                                                  self.bot, e, None, None,
                                                  None, message)

    async def send_to_dash(self, channel, **kwargs):
        await self.redis_link.publish_json("bot-dash-messages",
                                           dict(type=channel, message=kwargs))

    async def question(self, message):
        try:
            reply = dict(reply=await self.question_handlers[message["type"]]
                         (message["data"]),
                         state="OK",
                         uid=message["uid"])
        except UnauthorizedException:
            reply = dict(uid=message["uid"], state="Unauthorized")
        except ValidationException as ex:
            reply = dict(uid=message["uid"],
                         state="Bad Request",
                         errors=ex.errors)
        except CancelledError:
            return
        except Exception as ex:
            reply = dict(uid=message["uid"], state="Failed")
            await self.send_to_dash("reply", **reply)
            raise ex
        await self.send_to_dash("reply", **reply)

    async def _receiver(self):
        async for sender, message in self.receiver.iter(encoding='utf-8',
                                                        decoder=json.loads):
            self.bot.loop.create_task(self._handle(sender, message))

    async def still_spinning(self, _):
        self.last_dash_heartbeat[0] = time.time()
        self.last_dash_heartbeat[1] = 0
        self.last_dash_heartbeat[2] = 0

        return self.bot.latency

    async def user_info_request(self, message):
        user_id = message["user_id"]
        user_info = await self.bot.fetch_user(user_id)
        return_info = {
            "username":
            user_info.name,
            "discrim":
            user_info.discriminator,
            "avatar_url":
            str(user_info.avatar_url_as(size=256)),
            "bot_admin_status":
            await self.bot.is_owner(user_info)
            or user_id in Configuration.get_master_var("BOT_ADMINS", [])
        }

        return return_info

    async def user_guilds(self, message):
        user_id = int(message["user_id"])
        self.bot.dash_guild_users.add(user_id)
        self.redis_link.publish_json(
            "bot-dash-messages",
            dict(type="guild_add",
                 message=dict(user_id=user_id,
                              guilds=DashUtils.get_user_guilds(
                                  self.bot, user_id))))

    async def user_guilds_end(self, message):
        user_id = int(message["user_id"])
        self.bot.dash_guild_users.remove(user_id)

    async def guild_user_perms(self, message):
        guild = self.bot.get_guild(int(message["guild_id"]))
        if guild is None:
            return 0
        return DashUtils.get_guild_perms(
            guild.get_member(int(message["user_id"])))

    @needs_perm(DASH_PERMS.ACCESS)
    async def guild_info_watch(self, message):
        # start tracking info
        guild_id, user_id = get_info(message)
        if guild_id not in self.bot.dash_guild_watchers:
            self.bot.dash_guild_watchers[guild_id] = set()
        self.bot.dash_guild_watchers[guild_id].add(user_id)
        await self.send_guild_info(
            self.bot.get_guild(guild_id).get_member(user_id))

    async def guild_info_watch_end(self, message):
        guild_id, user_id = get_info(message)
        if guild_id in self.bot.dash_guild_watchers:
            users = self.bot.dash_guild_watchers[guild_id]
            users.remove(user_id)
            if len(users) is 0:
                del self.bot.dash_guild_watchers[guild_id]

    async def send_guild_info_update_to_all(self, guild):
        if guild.id in self.bot.dash_guild_watchers:
            for user in self.bot.dash_guild_watchers[guild.id]:
                await self.send_guild_info(guild.get_member(user))

    async def send_guild_info(self, member):
        await self.send_to_dash("guild_update",
                                user_id=member.id,
                                guild_id=member.guild.id,
                                info=DashUtils.assemble_guild_info(
                                    self.bot, member))

    @needs_perm(DASH_PERMS.VIEW_CONFIG)
    async def get_guild_settings(self, message):
        section = Configuration.get_var(int(message["guild_id"]),
                                        message["section"])
        section = {
            k: [str(rid) if isinstance(rid, int) else rid
                for rid in v] if isinstance(v, list) else
            str(v) if isinstance(v, int) and not isinstance(v, bool) else v
            for k, v in section.items()
        }
        return section

    @needs_perm(DASH_PERMS.ALTER_CONFIG)
    async def save_guild_settings(self, message):
        guild_id, user_id = get_info(message)
        guild = self.bot.get_guild(guild_id)
        return DashConfig.update_config_section(guild, message["section"],
                                                message["modified_values"],
                                                guild.get_member(user_id))

    @needs_perm(DASH_PERMS.ALTER_CONFIG)
    async def replace_guild_settings(self, message):
        guild_id, user_id = get_info(message)
        guild = self.bot.get_guild(guild_id)
        return DashConfig.update_config_section(guild,
                                                message["section"],
                                                message["modified_values"],
                                                guild.get_member(user_id),
                                                replace=True)

    async def cache_info(self, message=None):
        return {
            'languages': Translator.LANG_NAMES,
            'logging': {
                k: list(v.keys())
                for k, v in GearbotLogging.LOGGING_INFO.items()
            }
        }

    @needs_perm(DASH_PERMS.ALTER_CONFIG)
    async def setup_mute(self, message):
        await self.override_handler(
            message, "setup", dict(send_messages=False, add_reactions=False),
            dict(speak=False, connect=False, stream=False))

    @needs_perm(DASH_PERMS.ALTER_CONFIG)
    async def cleanup_mute(self, message):
        await self.override_handler(message, "cleanup", None, None)

    async def override_handler(self, message, t, text, voice):
        guild = self.bot.get_guild(message["guild_id"])

        if not DashConfig.is_numeric(message["role_id"]):
            raise ValidationException(dict(role_id="Not a valid id"))

        role = guild.get_role(int(message["role_id"]))
        if role is None:
            raise ValidationException(dict(role_id="Not a valid id"))
        if role.id == guild.id:
            raise ValidationException(
                dict(
                    role_id="The @everyone role can't be used for muting people"
                ))
        if role.managed:
            raise ValidationException(
                dict(
                    role_id=
                    "Managed roles can not be assigned to users and thus won't work for muting people"
                ))
        user = await Utils.get_user(message["user_id"])
        parts = {
            "role_name": Utils.escape_markdown(role.name),
            "role_id": role.id,
            "user": Utils.clean_user(user),
            "user_id": user.id
        }
        GearbotLogging.log_key(guild.id, f"config_mute_{t}_triggered", **parts)
        failed = []
        for channel in guild.text_channels:
            try:
                if text is None:
                    await channel.set_permissions(role,
                                                  reason=Translator.translate(
                                                      f'mute_{t}', guild.id),
                                                  overwrite=None)
                else:
                    await channel.set_permissions(role,
                                                  reason=Translator.translate(
                                                      f'mute_{t}', guild.id),
                                                  **text)
            except Forbidden as ex:
                failed.append(channel.mention)
        for channel in guild.voice_channels:
            try:
                if voice is None:
                    await channel.set_permissions(role,
                                                  reason=Translator.translate(
                                                      f'mute_{t}', guild.id),
                                                  overwrite=None)
                else:
                    await channel.set_permissions(role,
                                                  reason=Translator.translate(
                                                      f'mute_{t}', guild.id),
                                                  **voice)
            except Forbidden as ex:
                failed.append(
                    Translator.translate('voice_channel',
                                         guild.id,
                                         channel=channel.name))

        await asyncio.sleep(
            1
        )  # delay logging so the channel overrides can get queried and logged
        GearbotLogging.log_key(guild.id, f"config_mute_{t}_complete", **parts)

        out = '\n'.join(failed)
        GearbotLogging.log_key(
            guild.id,
            f"config_mute_{t}_failed",
            **parts,
            count=len(failed),
            tag_on=None if len(failed) is 0 else f'```{out}```')

    # crowdin
    async def crowdin_webhook(self, message):
        code = message["info"]["language"]
        await Translator.update_lang(code)
        if (datetime.now() - self.last_update).seconds > 5 * 60:
            self.update_message = None
            self.to_log = dict()
        if code not in self.to_log:
            self.to_log[code] = 0
        self.to_log[code] += 1

        embed = Embed(color=Color(0x1183f6),
                      timestamp=datetime.utcfromtimestamp(time.time()),
                      description=f"**Live translation update summary!**\n" +
                      '\n'.join(f"{Translator.LANG_NAMES[code]} : {count}"
                                for code, count in self.to_log.items()))
        if self.update_message is None:
            self.update_message = await Translator.get_translator_log_channel(
            )(embed=embed)
        else:
            await self.update_message.edit(embed=embed)

        self.last_update = datetime.now()

    async def update(self, message):
        t = message["type"]
        if t == "update":
            await Update.update("whoever just pushed to master", self.bot)
        elif t == "upgrade":
            await Update.upgrade("whoever just pushed to master", self.bot)
        else:
            raise RuntimeError(
                "UNKNOWN UPDATE MESSAGE, IS SOMEONE MESSING WITH IT?")

    @commands.Cog.listener()
    async def on_guild_join(self, guild):
        for user in self.bot.dash_guild_users:
            member = guild.get_member(user)
            if member is not None:
                permission = DashUtils.get_guild_perms(member)
                if permission > 0:
                    await self.send_to_dash("guild_add",
                                            user_id=user,
                                            guilds={
                                                str(guild.id): {
                                                    "id": str(guild.id),
                                                    "name": guild.name,
                                                    "permissions": permission,
                                                    "icon": guild.icon
                                                }
                                            })

    @commands.Cog.listener()
    async def on_guild_remove(self, guild):
        for user in self.bot.dash_guild_users:
            member = guild.get_member(user)
            if member is not None:
                permission = DashUtils.get_guild_perms(member)
                if permission > 0:
                    await self.send_to_dash("guild_remove",
                                            user_id=user,
                                            guild=str(guild.id))

    @commands.Cog.listener()
    async def on_guild_update(self, before, after):
        for user in self.bot.dash_guild_users:
            member = after.get_member(user)
            if member is not None:
                old = DashUtils.get_guild_perms(member)
                new = DashUtils.get_guild_perms(member)
                if old != new:
                    await self._notify_user(member, old, new, after)
                elif before.name != after.name or before.icon != after.icon:
                    await self._notify_user(member, 0, 15, after)

    @commands.Cog.listener()
    async def on_member_update(self, before, after):
        if after.id in self.bot.dash_guild_users:
            old = DashUtils.get_guild_perms(before)
            new = DashUtils.get_guild_perms(after)
            await self._notify_user(after, old, new, before.guild)

    @commands.Cog.listener()
    async def on_guild_role_update(self, before, after):
        for user in self.bot.dash_guild_users:
            member = after.guild.get_member(user)
            if member is not None and after in member.roles:
                new = DashUtils.get_guild_perms(member)
                await self._notify_user(member, 0 if new is not 0 else 15, new,
                                        after.guild)

    @commands.Cog.listener()
    async def _notify_user(self, user, old, new, guild):
        if old != new:
            if new is not 0:
                await self.send_to_dash("guild_add",
                                        user_id=user.id,
                                        guilds={
                                            str(guild.id): {
                                                "id": str(guild.id),
                                                "name": guild.name,
                                                "permissions": new,
                                                "icon": guild.icon
                                            }
                                        })
        if new is 0 and old is not 0:
            await self.send_to_dash("guild_remove",
                                    user_id=user.id,
                                    guild=str(guild.id))
Example #20
class RedisBackend(BroadcastBackend):
    def __init__(self, url: str):
        self.conn_url = url

        self._pub_conn: typing.Optional[aioredis.Redis] = None
        self._sub_conn: typing.Optional[aioredis.Redis] = None

        self._msg_queue: typing.Optional[asyncio.Queue] = None
        self._reader_task: typing.Optional[asyncio.Task] = None
        self._mpsc: typing.Optional[Receiver] = None

    async def connect(self) -> None:
        if self._pub_conn or self._sub_conn or self._msg_queue:
            logger.warning("connections are already setup but connect called again; not doing anything")
            return

        self._pub_conn = await aioredis.create_redis(self.conn_url)
        self._sub_conn = await aioredis.create_redis(self.conn_url)
        self._msg_queue = asyncio.Queue()  # must be created here, to get proper event loop
        self._mpsc = Receiver()
        self._reader_task = asyncio.create_task(self._reader())

    async def disconnect(self) -> None:
        if self._pub_conn and self._sub_conn:
            self._pub_conn.close()
            self._sub_conn.close()
        else:
            logger.warning("connections are not setup, invalid call to disconnect")

        self._pub_conn = None
        self._sub_conn = None
        self._msg_queue = None

        if self._mpsc:
            self._mpsc.stop()
        else:
            logger.warning("redis mpsc receiver is not set, cannot stop it")

        if self._reader_task:
            if self._reader_task.done():
                self._reader_task.result()
            else:
                logger.debug("cancelling reader task")
                self._reader_task.cancel()
                self._reader_task = None

    async def subscribe(self, channel: str) -> None:
        if not self._sub_conn:
            logger.error(f"not connected, cannot subscribe to channel {channel!r}")
            return

        await self._sub_conn.subscribe(self._mpsc.channel(channel))

    async def unsubscribe(self, channel: str) -> None:
        if not self._sub_conn:
            logger.error(f"not connected, cannot unsubscribe from channel {channel!r}")
            return

        await self._sub_conn.unsubscribe(channel)

    async def publish(self, channel: str, message: typing.Any) -> None:
        if not self._pub_conn:
            logger.error(f"not connected, cannot publish to channel {channel!r}")
            return

        await self._pub_conn.publish_json(channel, message)

    async def next_published(self) -> Event:
        if not self._msg_queue:
            raise RuntimeError("unable to get next_published event, RedisBackend is not connected")

        return await self._msg_queue.get()

    async def _reader(self) -> None:
        async for channel, msg in self._mpsc.iter(encoding="utf8", decoder=json.loads):
            if not isinstance(channel, AbcChannel):
                logger.error(f"invalid channel returned from Receiver().iter() - {channel!r}")
                continue

            channel_name = channel.name.decode("utf8")

            if not self._msg_queue:
                logger.error(f"unable to put new message from {channel_name} into queue, not connected")
                continue

            await self._msg_queue.put(Event(channel=channel_name, message=msg))
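The backend above pairs publish_json on the publishing connection with Receiver.iter(encoding=..., decoder=json.loads) on the subscribing one, so consumers receive parsed objects instead of raw bytes. A self-contained round-trip sketch of that combination, with a placeholder channel name and an assumed local Redis server:

import asyncio
import json

import aioredis
from aioredis.pubsub import Receiver


async def round_trip(redis_url="redis://localhost"):
    # Separate connections for publishing and subscribing, as in RedisBackend above.
    pub = await aioredis.create_redis(redis_url)
    sub = await aioredis.create_redis(redis_url)
    mpsc = Receiver()
    await sub.subscribe(mpsc.channel("events"))

    await pub.publish_json("events", {"kind": "ping"})

    # decoder=json.loads yields parsed dicts instead of raw bytes.
    async for channel, message in mpsc.iter(encoding="utf-8", decoder=json.loads):
        print(channel.name.decode(), message)
        break

    mpsc.stop()
    pub.close()
    sub.close()
    await pub.wait_closed()
    await sub.wait_closed()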
Example #21
class DashLink(BaseCog):

    def __init__(self, bot):
        super().__init__(bot)
        bot.loop.create_task(self.init())
        self.redis_link = None
        self.receiver = Receiver(loop=bot.loop)
        self.handlers = dict(
            guild_perm_request=self.guild_perm_request
        )
        self.recieve_handlers = dict(

        )
        self.last_update = datetime.now()
        self.to_log = dict()
        self.update_message = None

        if Configuration.get_master_var("TRANSLATIONS", dict(SOURCE="SITE", CHANNEL=0, KEY= "", LOGIN="", WEBROOT=""))["SOURCE"] == 'CROWDIN':
            self.recieve_handlers["crowdin_webhook"] = self.crowdin_webhook
        self.task = self._receiver()

    def cog_unload(self):
        self.bot.loop.create_task(self._unload())

    async def _unload(self):
        for c in self.receiver.channels.values():
            self.redis_link.unsubscribe(c)
        self.receiver.stop()
        self.redis_link.close()
        await self.redis_link.wait_closed()

    async def init(self):
        try:
            self.redis_link = await aioredis.create_redis_pool(
                (Configuration.get_master_var('REDIS_HOST', "localhost"), Configuration.get_master_var('REDIS_PORT', 6379)),
                encoding="utf-8", db=0, maxsize=2) # size 2: one send, one receive
            self.bot.loop.create_task(self._receiver())
            await self.redis_link.subscribe(self.receiver.channel("dash-bot-messages"))
        except OSError:
            await GearbotLogging.bot_log("Failed to connect to the dash!")

    async def _receiver(self):
        async for sender, message in self.receiver.iter(encoding='utf-8', decoder=json.loads):
            try:
                if message["type"] in self.recieve_handlers.keys():
                    await self.recieve_handlers[message["type"]](message)
                else:
                    reply = dict(reply=await self.handlers[message["type"]](message), uid=message["uid"])
                    await self.redis_link.publish_json("bot-dash-messages", reply)
            except Exception as e:
                await TheRealGearBot.handle_exception("Dash message handling", self.bot, e, None, None, None, message)


    async def guild_perm_request(self, message):
        info = dict()
        for guid in message["guild_list"]:
            guid = int(guid)
            guild = self.bot.get_guild(guid)
            permission = 0
            if guild is not None:
                member = guild.get_member(int(message["user_id"]))
                mod_roles = Configuration.get_var(guid, "MOD_ROLES")
                if member.guild_permissions.ban_members or any(r.id in mod_roles for r in member.roles):
                    permission |= (1 << 0) # dash access
                    permission |= (1 << 1) # infraction access

                admin_roles = Configuration.get_var(guid, "ADMIN_ROLES")
                if member.guild_permissions.administrator or any(r.id in admin_roles for r in member.roles):
                    permission |= (1 << 0)  # dash access
                    permission |= (1 << 2)  # config read access
                    permission |= (1 << 3)  # config write access

            if permission > 0:
                info[guid] = dict(name=guild.name, permissions=permission, icon=guild.icon_url_as(size=256))
        return info



    #crowdin
    async def crowdin_webhook(self, message):
        code = message["info"]["language"]
        await Translator.update_lang(code)
        if (datetime.now() - self.last_update).seconds > 5*60:
            self.update_message = None
            self.to_log = dict()
        if code not in self.to_log:
            self.to_log[code] = 0
        self.to_log[code] += 1

        embed = Embed(color=Color(0x1183f6), timestamp=datetime.utcfromtimestamp(time.time()),
                      description=f"**Live translation update summary!**\n" + '\n'.join(
                          f"{Translator.LANG_NAMES[code]} : {count}" for code, count in self.to_log.items()))
        if self.update_message is None:
            self.update_message = await Translator.get_translator_log_channel()(embed=embed)
        else:
            await self.update_message.edit(embed=embed)

        self.last_update = datetime.now()
Example #22
    async def _update_actors(self):
        # TODO(fyrestone): Refactor code for updating actor / node / job.
        # Subscribe actor channel.
        aioredis_client = self._dashboard_head.aioredis_client
        receiver = Receiver()

        key = "{}:*".format(actor_consts.ACTOR_CHANNEL)
        pattern = receiver.pattern(key)
        await aioredis_client.psubscribe(pattern)
        logger.info("Subscribed to %s", key)

        def _process_actor_table_data(data):
            actor_class = actor_classname_from_task_spec(
                data.get("taskSpec", {}))
            data["actorClass"] = actor_class

        # Get all actor info.
        while True:
            try:
                logger.info("Getting all actor info from GCS.")
                request = gcs_service_pb2.GetAllActorInfoRequest()
                reply = await self._gcs_actor_info_stub.GetAllActorInfo(
                    request, timeout=5)
                if reply.status.code == 0:
                    actors = {}
                    for message in reply.actor_table_data:
                        actor_table_data = actor_table_data_to_dict(message)
                        _process_actor_table_data(actor_table_data)
                        actors[actor_table_data["actorId"]] = actor_table_data
                    # Update actors.
                    DataSource.actors.reset(actors)
                    # Update node actors and job actors.
                    job_actors = {}
                    node_actors = {}
                    for actor_id, actor_table_data in actors.items():
                        job_id = actor_table_data["jobId"]
                        node_id = actor_table_data["address"]["rayletId"]
                        job_actors.setdefault(job_id,
                                              {})[actor_id] = actor_table_data
                        # Update only when node_id is not Nil.
                        if node_id != actor_consts.NIL_NODE_ID:
                            node_actors.setdefault(
                                node_id, {})[actor_id] = actor_table_data
                    DataSource.job_actors.reset(job_actors)
                    DataSource.node_actors.reset(node_actors)
                    logger.info("Received %d actor info from GCS.",
                                len(actors))
                    break
                else:
                    raise Exception(
                        f"Failed to GetAllActorInfo: {reply.status.message}")
            except Exception:
                logger.exception("Error Getting all actor info from GCS.")
                await asyncio.sleep(
                    actor_consts.RETRY_GET_ALL_ACTOR_INFO_INTERVAL_SECONDS)

        # Receive actors from channel.
        state_keys = ("state", "address", "numRestarts", "timestamp", "pid")
        async for sender, msg in receiver.iter():
            try:
                actor_id, actor_table_data = msg
                pubsub_message = ray.gcs_utils.PubSubMessage.FromString(
                    actor_table_data)
                message = ray.gcs_utils.ActorTableData.FromString(
                    pubsub_message.data)
                actor_table_data = actor_table_data_to_dict(message)
                _process_actor_table_data(actor_table_data)
                # If the actor is not newly registered but updated, we only
                # update state-related fields.
                if actor_table_data["state"] != "DEPENDENCIES_UNREADY":
                    actor_id = actor_id.decode(
                        "UTF-8")[len(ray.gcs_utils.TablePrefix_ACTOR_string +
                                     ":"):]
                    actor_table_data_copy = dict(DataSource.actors[actor_id])
                    for k in state_keys:
                        actor_table_data_copy[k] = actor_table_data[k]
                    actor_table_data = actor_table_data_copy
                actor_id = actor_table_data["actorId"]
                job_id = actor_table_data["jobId"]
                node_id = actor_table_data["address"]["rayletId"]
                # Update actors.
                DataSource.actors[actor_id] = actor_table_data
                # Update node actors (only when node_id is not Nil).
                if node_id != actor_consts.NIL_NODE_ID:
                    node_actors = dict(DataSource.node_actors.get(node_id, {}))
                    node_actors[actor_id] = actor_table_data
                    DataSource.node_actors[node_id] = node_actors
                # Update job actors.
                job_actors = dict(DataSource.job_actors.get(job_id, {}))
                job_actors[actor_id] = actor_table_data
                DataSource.job_actors[job_id] = job_actors
            except Exception:
                logger.exception("Error receiving actor info.")
Example #23
class Controller:
    def __init__(self):
        self.receiver = Receiver()
        self.marks = []

        self.last_average_payload = 0

    async def reader(self):
        await cache.connection.subscribe(self.receiver.channel('marks'))

        async for channel, message in self.receiver.iter():
            mark = json.loads(message)
            self.marks.append(mark)

    async def control(self):
        status_data = self.get_status()
        if status_data is None:
            return

        status, average_payload, aggregated_marks_number = status_data
        status_datetime = datetime.datetime.now()

        try:
            self.send(status, status_datetime)
        except socket.error:
            return

        await self.update(status, average_payload, status_datetime,
                          aggregated_marks_number)

    def get_status(self):

        if not self.marks:
            return None

        aggregated_marks_number = len(self.marks)

        average_payload = sum(mark['payload'] for mark in islice(
            self.marks, aggregated_marks_number)) / aggregated_marks_number

        status = 'down' if average_payload < self.last_average_payload else 'up'

        return status, average_payload, aggregated_marks_number

    async def update(self, status: str, average_payload: float,
                     status_datetime: datetime.datetime,
                     aggregated_marks_number: int):

        self.last_average_payload = average_payload
        self.marks = self.marks[aggregated_marks_number:]

        with context_session() as db:
            command = models.Command(status=status, datetime=status_datetime)
            db.add(command)
            db.commit()

        await cache.connection.set(last_command_datetime_key,
                                   status_datetime.isoformat())

    def send(self, status: str, datetime: datetime.datetime):
        payload = {
            'datetime': datetime,
            'status': status,
        }
        encoded_payload = json.dumps(payload,
                                     cls=DateTimeEncoder).encode('utf-8')

        conn = socket.create_connection(
            (settings.global_settings.manipulator_host,
             settings.global_settings.manipulator_port))
        conn.sendall(encoded_payload)
        conn.close()
Example #24
class Broadcast:
    def __init__(self, connection_url: str) -> None:
        self.connection_url = connection_url
        self._pub_conn: Optional[aioredis.Redis] = None
        self._sub_conn: Optional[aioredis.Redis] = None
        self._receiver: Optional[Receiver] = None

    async def connect(
        self, *, conn_retries: int = 5, conn_retry_delay: int = 1, retry: int = 0
    ) -> "Broadcast":
        try:
            self._pub_conn = await aioredis.create_redis(self.connection_url)
            self._sub_conn = await aioredis.create_redis(self.connection_url)
            self._receiver = Receiver()
        except (ConnectionError, aioredis.RedisError, asyncio.TimeoutError) as e:
            if retry < conn_retries:
                logger.warning(
                    "Redis connection error %s %s %s, %d retries remaining...",
                    self.connection_url,
                    e.__class__.__name__,
                    e,
                    conn_retries - retry,
                )
                await asyncio.sleep(conn_retry_delay)
            else:
                logger.error("Connecting to Redis failed")
                raise
        else:
            if retry > 0:
                logger.info("Redis connection successful")
            return self

        return await self.connect(
            conn_retries=conn_retries,
            conn_retry_delay=conn_retry_delay,
            retry=retry + 1,
        )

    async def publish(
        self, channel: str, message: dict, retry_count: Optional[int] = None
    ) -> None:
        if not self.has_pub_conn():
            if not retry_count:
                return None
            try:
                await self.connect(conn_retries=retry_count)
            except ConnectionRefusedError:
                logger.warning(
                    f"Retry failed - cannot publish message to channel {channel}"
                )
                return None
        # Publish once a connection is available, including right after a reconnect.
        await self._pub_conn.publish_json(channel=channel, obj=message)

    async def unsubscribe(self, channel: str) -> None:
        if self.has_sub_conn():
            await self._sub_conn.unsubscribe(channel)

    async def disconnect(self) -> None:
        # Guard against being called before (or without) a successful connect().
        for conn in (self._pub_conn, self._sub_conn):
            if conn is not None:
                conn.close()
                await conn.wait_closed()

    @asynccontextmanager
    async def subscribe(
        self, channel: str, retry_count: Optional[int] = None
    ) -> AsyncIterator:
        if not self.has_sub_conn():
            if retry_count:
                try:
                    await self.connect(conn_retries=1)
                except ConnectionRefusedError:
                    logger.warning(
                        f"Retry failed; cannot subscribe to channel {channel}."
                    )
                    yield None
                    return
            else:
                yield None
                return

        await self._sub_conn.subscribe(self._receiver.channel(channel))

        try:
            yield self._receiver.iter(encoding="utf8", decoder=json.loads)
        finally:
            # Unsubscribe even if the consumer raises inside the context block.
            await self.unsubscribe(channel)

    def has_sub_conn(self) -> bool:
        return not (self._sub_conn is None or self._sub_conn.closed)

    def has_pub_conn(self) -> bool:
        return not (self._pub_conn is None or self._pub_conn.closed)
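A short usage sketch for Broadcast, assuming a local Redis at "redis://localhost" and an illustrative "events" channel; neither appears in the snippet itself, and error handling is kept minimal.

import asyncio


async def demo() -> None:
    broadcast = await Broadcast("redis://localhost").connect()
    try:
        async with broadcast.subscribe("events") as messages:
            if messages is None:
                return  # no subscriber connection could be established
            await broadcast.publish("events", {"hello": "world"})
            async for channel, payload in messages:
                print(payload)  # already json-decoded by the Receiver iterator
                break
    finally:
        await broadcast.disconnect()

# asyncio.run(demo())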
Beispiel #25
0
    async def _update_actors(self):
        # Subscribe actor channel.
        aioredis_client = self._dashboard_head.aioredis_client
        receiver = Receiver()

        key = "{}:*".format(stats_collector_consts.ACTOR_CHANNEL)
        pattern = receiver.pattern(key)
        await aioredis_client.psubscribe(pattern)
        logger.info("Subscribed to %s", key)

        def _process_actor_table_data(data):
            actor_class = actor_classname_from_task_spec(
                data.get("taskSpec", {}))
            data["actorClass"] = actor_class

        # Get all actor info.
        while True:
            try:
                logger.info("Getting all actor info from GCS.")
                request = gcs_service_pb2.GetAllActorInfoRequest()
                reply = await self._gcs_actor_info_stub.GetAllActorInfo(
                    request, timeout=5)
                if reply.status.code == 0:
                    actors = {}
                    for message in reply.actor_table_data:
                        actor_table_data = actor_table_data_to_dict(message)
                        _process_actor_table_data(actor_table_data)
                        actors[actor_table_data["actorId"]] = actor_table_data
                    # Update actors.
                    DataSource.actors.reset(actors)
                    # Update node actors and job actors.
                    job_actors = {}
                    node_actors = {}
                    for actor_id, actor_table_data in actors.items():
                        job_id = actor_table_data["jobId"]
                        node_id = actor_table_data["address"]["rayletId"]
                        job_actors.setdefault(job_id,
                                              {})[actor_id] = actor_table_data
                        node_actors.setdefault(node_id,
                                               {})[actor_id] = actor_table_data
                    DataSource.job_actors.reset(job_actors)
                    DataSource.node_actors.reset(node_actors)
                    logger.info("Received %d actor info from GCS.",
                                len(actors))
                    break
                else:
                    raise Exception(
                        f"Failed to GetAllActorInfo: {reply.status.message}")
            except Exception:
                logger.exception("Error Getting all actor info from GCS.")
                await asyncio.sleep(stats_collector_consts.
                                    RETRY_GET_ALL_ACTOR_INFO_INTERVAL_SECONDS)

        # Receive actors from channel.
        async for sender, msg in receiver.iter():
            try:
                _, actor_table_data = msg
                pubsub_message = ray.gcs_utils.PubSubMessage.FromString(
                    actor_table_data)
                message = ray.gcs_utils.ActorTableData.FromString(
                    pubsub_message.data)
                actor_table_data = actor_table_data_to_dict(message)
                _process_actor_table_data(actor_table_data)
                actor_id = actor_table_data["actorId"]
                job_id = actor_table_data["jobId"]
                node_id = actor_table_data["address"]["rayletId"]
                # Update actors.
                DataSource.actors[actor_id] = actor_table_data
                # Update node actors.
                node_actors = dict(DataSource.node_actors.get(node_id, {}))
                node_actors[actor_id] = actor_table_data
                DataSource.node_actors[node_id] = node_actors
                # Update job actors.
                job_actors = dict(DataSource.job_actors.get(job_id, {}))
                job_actors[actor_id] = actor_table_data
                DataSource.job_actors[job_id] = job_actors
            except Exception:
                logger.exception("Error receiving actor info.")
Beispiel #26
0
class DashLink:
    def __init__(self, bot) -> None:
        self.bot: GearBot = bot
        bot.loop.create_task(self.init())
        self.redis_link = None
        self.receiver = Receiver(loop=bot.loop)
        self.handlers = dict(guild_perm_request=self.guild_perm_request)
        self.task = None  # set by init() once the Redis link is established

    def __unload(self):
        self.bot.loop.create_task(self._unload())

    async def _unload(self):
        for c in self.receiver.channels.values():
            await self.redis_link.unsubscribe(c)
        self.receiver.stop()
        self.redis_link.close()
        await self.redis_link.wait_closed()

    async def init(self):
        try:
            self.redis_link = await aioredis.create_redis_pool(
                (Configuration.get_master_var('REDIS_HOST', "localhost"),
                 Configuration.get_master_var('REDIS_PORT', 6379)),
                encoding="utf-8",
                db=0,
                maxsize=2)  # size 2: one send, one receive
            self.task = self.bot.loop.create_task(self._receiver())
            await self.redis_link.subscribe(
                self.receiver.channel("dash-bot-messages"))
        except OSError:
            await GearbotLogging.bot_log("Failed to connect to the dash!")

    async def _receiver(self):
        async for sender, message in self.receiver.iter(encoding='utf-8',
                                                        decoder=json.loads):
            try:
                reply = dict(reply=await self.handlers[message["type"]]
                             (message),
                             uid=message["uid"])
                await self.redis_link.publish_json("bot-dash-messages", reply)
            except Exception as e:
                await TheRealGearBot.handle_exception("Dash message handling",
                                                      self.bot, e, None, None,
                                                      None, message)

    async def guild_perm_request(self, message):
        info = dict()
        for guid in message["guild_list"]:
            guid = int(guid)
            guild = self.bot.get_guild(guid)
            permission = 0
            if guild is not None:
                member = guild.get_member(int(message["user_id"]))
                if member is None:
                    continue  # user is not (or no longer) a member of this guild
                mod_roles = Configuration.get_var(guid, "MOD_ROLES")
                if member.guild_permissions.ban_members or any(
                        r.id in mod_roles for r in member.roles):
                    permission |= (1 << 0)  # dash access
                    permission |= (1 << 1)  # infraction access

                admin_roles = Configuration.get_var(guid, "ADMIN_ROLES")
                if member.guild_permissions.administrator or any(
                        r.id in admin_roles for r in member.roles):
                    permission |= (1 << 0)  # dash access
                    permission |= (1 << 2)  # config read access
                    permission |= (1 << 3)  # config write access

            if permission > 0:
                info[guid] = dict(name=guild.name,
                                  permissions=permission,
                                  icon=guild.icon_url_as(size=256))
        return info
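For context, a hypothetical sketch of the dashboard side of this exchange: publish a typed request on "dash-bot-messages" and wait for the matching uid on "bot-dash-messages". The channel names and message fields are read off the snippet above; request_guild_permissions and its parameters are illustrative, not part of the original project.

import asyncio
import json
import uuid

import aioredis
from aioredis.pubsub import Receiver


async def request_guild_permissions(redis_url, user_id, guild_ids, timeout=5):
    redis = await aioredis.create_redis_pool(redis_url, encoding="utf-8")
    receiver = Receiver()
    await redis.subscribe(receiver.channel("bot-dash-messages"))

    uid = str(uuid.uuid4())
    await redis.publish_json("dash-bot-messages", {
        "type": "guild_perm_request",
        "uid": uid,
        "user_id": str(user_id),
        "guild_list": [str(g) for g in guild_ids],
    })

    async def wait_for_reply():
        # Replies for other requests share the channel, so match on uid.
        async for _, message in receiver.iter(encoding="utf-8",
                                              decoder=json.loads):
            if message.get("uid") == uid:
                return message["reply"]

    try:
        return await asyncio.wait_for(wait_for_reply(), timeout)
    finally:
        receiver.stop()
        redis.close()
        await redis.wait_closed()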