예제 #1
0
async def test_pubsub_receiver_iter(create_redis, server, loop):
    """``Receiver.iter()`` yields every (sender, message) pair until stopped.

    One Receiver is attached to two plain channels and to a pattern that
    matches both of them, so each publish is delivered twice: once through
    the channel sender (raw payload) and once through the pattern sender
    (``(channel_name, payload)`` tuple).
    """
    sub = await create_redis(server.tcp_address, loop=loop)
    pub = await create_redis(server.tcp_address, loop=loop)

    mpsc = Receiver(loop=loop)

    async def coro(mpsc):
        # Drain the receiver until it is stopped and return what was seen.
        lst = []
        async for msg in mpsc.iter():
            lst.append(msg)
        return lst

    tsk = asyncio.ensure_future(coro(mpsc), loop=loop)
    snd1, = await sub.subscribe(mpsc.channel('chan:1'))
    snd2, = await sub.subscribe(mpsc.channel('chan:2'))
    snd3, = await sub.psubscribe(mpsc.pattern('chan:*'))

    # A channel subscription and a pattern subscription both match each
    # channel, hence more than one reported receiver per publish.
    subscribers = await pub.publish_json('chan:1', {'Hello': 'World'})
    assert subscribers > 1
    subscribers = await pub.publish_json('chan:2', ['message'])
    assert subscribers > 1
    # Stop the receiver on the next loop iteration so the iterator ends.
    # NOTE(review): nothing waits for the sub connection to have read the
    # published messages before stop runs -- a possible source of flakiness.
    loop.call_later(0, mpsc.stop)
    # await asyncio.sleep(0, loop=loop)
    assert await tsk == [
        (snd1, b'{"Hello": "World"}'),
        (snd3, (b'chan:1', b'{"Hello": "World"}')),
        (snd2, b'["message"]'),
        (snd3, (b'chan:2', b'["message"]')),
        ]
    assert not mpsc.is_active
예제 #2
0
class Messager:
    """Request/reply helper on top of Redis pub/sub.

    Requests are published as JSON on ``outbound``. Replies are read from
    ``inbound`` and matched to their request through a ``"uid"`` field.
    """

    def __init__(self, inbound, outbound, loop):
        self.loop = loop
        self.conn = None
        self.inbound = inbound
        self.outbound = outbound
        self.receiver = None
        self.task = None
        # uid -> reply message; filled by fetcher(), consumed by _get_reply()
        self.replies = dict()
        # uids of requests that are still waiting for a reply
        self.expected = set()

    async def initialize(self):
        """Open the Redis pool, subscribe to ``inbound`` and start fetcher()."""
        self.conn = await aioredis.create_redis_pool(("localhost", 6379),
                                                     encoding="utf-8",
                                                     maxsize=2)
        self.receiver = Receiver(loop=self.loop)
        await self.conn.subscribe(self.receiver.channel(self.inbound))
        self.task = self.loop.create_task(self.fetcher())

        print(
            f'Redis connection established, listening on {self.inbound}, sending on {self.outbound}'
        )  # FIXME - proper logging

    async def terminate(self):
        """Unsubscribe, stop the reader task and close the Redis pool."""
        # Await the UNSUBSCRIBE so it completes (and any error surfaces)
        # before the pool is closed underneath it.
        await self.conn.unsubscribe(self.inbound)
        self.task.cancel()
        self.receiver.stop()
        self.conn.close()
        await self.conn.wait_closed()

    async def fetcher(self):
        """Store each reply received on ``inbound`` under its ``"uid"``."""
        async for sender, message in self.receiver.iter(encoding='utf-8',
                                                        decoder=json.loads):
            channel = sender.name.decode()
            if channel == self.inbound:
                # A payload without a "uid" used to raise KeyError here and
                # silently end this task, so every later request timed out.
                uid = message.get("uid") if isinstance(message, dict) else None
                if uid not in self.expected:
                    # Unknown uid, or a reply that arrived after its request
                    # had already timed out.
                    print("Unexpected message!")
                    print(message)
                else:
                    self.expected.remove(uid)
                    self.replies[uid] = message

    async def get_reply(self, data, timeout=10):
        """Publish *data* and return the ``"reply"`` field of the answer.

        Args:
            data: JSON-serialisable dict; a ``"uid"`` key is added in place.
            timeout: seconds to wait for the answer (default 10).

        Raises:
            Redisception: if no reply arrives within *timeout* seconds.
        """
        try:
            reply = await asyncio.wait_for(self._get_reply(data), timeout)
        except asyncio.TimeoutError:
            # asyncio.TimeoutError is not the builtin TimeoutError before
            # Python 3.11, so catching the builtin never matched there.
            raise Redisception("No reply received from the bot!") from None
        return reply["reply"]

    async def _get_reply(self, data):
        """Publish *data* under a fresh uid and poll until its reply arrives."""
        uid = str(uuid.uuid4())
        self.expected.add(uid)
        data["uid"] = uid
        try:
            await self.conn.publish_json(self.outbound, data)
            while uid not in self.replies:
                await asyncio.sleep(0.1)
            return self.replies.pop(uid)
        finally:
            # On timeout/cancellation, forget the uid so that `expected` does
            # not grow forever and a late reply is not stored in `replies`.
            self.expected.discard(uid)
            self.replies.pop(uid, None)
예제 #3
0
async def test_pubsub_receiver_iter(create_redis, server, event_loop):
    """``Receiver.iter()`` yields every (sender, message) pair until stopped.

    Two plain channels plus a pattern matching both are attached to one
    Receiver, so each publish arrives twice: raw via the channel sender and
    as ``(channel_name, payload)`` via the pattern sender.
    """
    sub = await create_redis(server.tcp_address)
    pub = await create_redis(server.tcp_address)

    mpsc = Receiver()

    async def coro(mpsc):
        # Drain the receiver until it is stopped and return what was seen.
        lst = []
        async for msg in mpsc.iter():
            lst.append(msg)
        return lst

    tsk = asyncio.ensure_future(coro(mpsc))
    (snd1, ) = await sub.subscribe(mpsc.channel("chan:1"))
    (snd2, ) = await sub.subscribe(mpsc.channel("chan:2"))
    (snd3, ) = await sub.psubscribe(mpsc.pattern("chan:*"))

    # Channel + pattern subscriptions both match, so more than one receiver.
    subscribers = await pub.publish_json("chan:1", {"Hello": "World"})
    assert subscribers > 1
    subscribers = await pub.publish_json("chan:2", ["message"])
    assert subscribers > 1
    event_loop.call_later(0, mpsc.stop)
    # Give the loop time to run the scheduled stop.
    # NOTE(review): stop is scheduled with delay 0, so it likely runs during
    # this sleep; the sleep does not by itself guarantee delivery first.
    await asyncio.sleep(0.01)
    assert await tsk == [
        (snd1, b'{"Hello": "World"}'),
        (snd3, (b"chan:1", b'{"Hello": "World"}')),
        (snd2, b'["message"]'),
        (snd3, (b"chan:2", b'["message"]')),
    ]
    assert not mpsc.is_active
예제 #4
0
async def test_stopped(create_connection, server, caplog):
    """A message published after ``Receiver.stop()`` is dropped with a warning.

    After the stop, ``get()`` must raise ``ChannelClosedError`` and
    ``wait_message()`` must return False.
    """
    sub = await create_connection(server.tcp_address)
    pub = await create_connection(server.tcp_address)

    mpsc = Receiver()
    await sub.execute_pubsub("subscribe", mpsc.channel("channel:1"))
    assert mpsc.is_active
    mpsc.stop()

    caplog.clear()
    with caplog.at_level("DEBUG", "aioredis"):
        await pub.execute("publish", "channel:1", b"Hello")
        # Yield once, presumably so the subscriber connection reads the message.
        await asyncio.sleep(0)

    # Exactly one log record: the "message after stop" warning below.
    assert len(caplog.record_tuples) == 1
    # Receiver must have 1 EndOfStream message
    message = (
        "Pub/Sub listener message after stop: "
        "sender: <_Sender name:b'channel:1', is_pattern:False, receiver:"
        "<Receiver is_active:False, senders:1, qsize:0>>, data: b'Hello'")
    assert caplog.record_tuples == [
        ("aioredis", logging.WARNING, message),
    ]

    # assert (await mpsc.get()) is None
    with pytest.raises(ChannelClosedError):
        await mpsc.get()
    res = await mpsc.wait_message()
    assert res is False
예제 #5
0
    def __init__(self, bot):
        """Set up the dashboard link for *bot*.

        Schedules ``self.init()`` on the bot's loop and creates the Redis
        ``Receiver``. Builds two lookup tables from message names to bound
        handler methods: ``handlers`` and ``question_handlers`` (presumably
        consulted by ``question``). A ``crowdin_webhook`` handler is added
        only when the TRANSLATIONS source is ``'CROWDIN'``.
        """
        super().__init__(bot)
        loop = bot.loop
        loop.create_task(self.init())
        self.redis_link: aioredis.Redis = None
        self.receiver = Receiver(loop=loop)
        self.handlers = {
            "question": self.question,
            "update": self.update,
            "user_guilds": self.user_guilds,
            "user_guilds_end": self.user_guilds_end,
            "guild_info_watch": self.guild_info_watch,
            "guild_info_watch_end": self.guild_info_watch_end,
        }
        self.question_handlers = {
            "heartbeat": self.still_spinning,
            "user_info": self.user_info_request,
            "get_guild_settings": self.get_guild_settings,
            "save_guild_settings": self.save_guild_settings,
            "replace_guild_settings": self.replace_guild_settings,
            "setup_mute": self.setup_mute,
            "cleanup_mute": self.cleanup_mute,
            "cache_info": self.cache_info,
            "guild_user_perms": self.guild_user_perms,
        }
        # [time of the last heartbeat received, current attempt number,
        #  number of times the owner has been notified]
        self.last_dash_heartbeat = [time.time(), 0, 0]
        self.last_update = datetime.now()
        self.to_log = {}
        self.update_message = None

        translation_defaults = {
            "SOURCE": "SITE",
            "CHANNEL": 0,
            "KEY": "",
            "LOGIN": "",
            "WEBROOT": "",
        }
        translations = Configuration.get_master_var("TRANSLATIONS",
                                                    translation_defaults)
        if translations["SOURCE"] == 'CROWDIN':
            self.handlers["crowdin_webhook"] = self.crowdin_webhook
        # NOTE(review): this stores the result of calling _receiver(); if that
        # is a coroutine function, nothing schedules the coroutine here.
        self.task = self._receiver()
예제 #6
0
    async def _update_error_info(self):
        """Subscribe to Ray error pub/sub messages and index them by ip and pid.

        Each message is decoded into ``ErrorTableData``. Messages whose text
        contains ``(pid=<n>, ip=<addr>)`` are appended, as dicts with
        ``message``/``timestamp``/``type``, to
        ``DataSource.ip_and_pid_to_errors[ip][pid]``. Per-message failures
        are logged and do not stop the loop.
        """
        aioredis_client = self._dashboard_head.aioredis_client
        receiver = Receiver()

        key = ray.gcs_utils.RAY_ERROR_PUBSUB_PATTERN
        pattern = receiver.pattern(key)
        await aioredis_client.psubscribe(pattern)
        logger.info("Subscribed to %s", key)

        async for _, msg in receiver.iter():
            try:
                # Pattern subscriptions deliver (channel, payload) pairs.
                _, data = msg
                pubsub_msg = ray.gcs_utils.PubSubMessage.FromString(data)
                error_data = ray.gcs_utils.ErrorTableData.FromString(
                    pubsub_msg.data)
                message = error_data.error_message
                # Strip ANSI color codes of the form ESC[<n>m before matching.
                message = re.sub(r"\x1b\[\d+m", "", message)
                match = re.search(r"\(pid=(\d+), ip=(.*?)\)", message)
                if match:
                    pid = match.group(1)
                    ip = match.group(2)
                    # Copy before mutating, so the stored dict/list is replaced
                    # instead of being edited in place. In-place edits are
                    # invisible to anything that detects changes on assignment.
                    # NOTE(review): assumes DataSource is change-tracked.
                    errs_for_ip = dict(DataSource.ip_and_pid_to_errors.get(ip, {}))
                    pid_errors = list(errs_for_ip.get(pid, []))
                    pid_errors.append({
                        "message": message,
                        "timestamp": error_data.timestamp,
                        "type": error_data.type
                    })
                    errs_for_ip[pid] = pid_errors
                    DataSource.ip_and_pid_to_errors[ip] = errs_for_ip
                    logger.info(f"Received error entry for {ip} {pid}")
            except Exception:
                logger.exception("Error receiving error info.")
예제 #7
0
파일: node_head.py 프로젝트: mvindiola1/ray
    async def _update_log_info(self):
        """Subscribe to the Ray log channel and cache log lines per ip and pid.

        Lines for each (ip, pid) accumulate in ``DataSource.ip_and_pid_to_logs``.
        Once a pid holds more than ``MAX_LOGS_TO_CACHE * LOG_PRUNE_THREASHOLD``
        lines, it is trimmed back to the newest ``MAX_LOGS_TO_CACHE``. Batches
        whose pid is ``"autoscaler"`` are not cached. Per-message failures are
        logged and do not stop the loop.
        """

        def process_log_batch(batch):
            ip = batch["ip"]
            pid = str(batch["pid"])
            if pid != "autoscaler":
                # Work on copies; the updated dict is assigned back below.
                ip_logs = dict(DataSource.ip_and_pid_to_logs.get(ip, {}))
                pid_lines = [*ip_logs.get(pid, []), *batch["lines"]]

                total = len(pid_lines)
                if total > MAX_LOGS_TO_CACHE * LOG_PRUNE_THREASHOLD:
                    # Keep only the newest MAX_LOGS_TO_CACHE entries.
                    pid_lines = pid_lines[total - MAX_LOGS_TO_CACHE:]

                ip_logs[pid] = pid_lines
                DataSource.ip_and_pid_to_logs[ip] = ip_logs
            logger.info(f"Received a log for {ip} and {pid}")

        aioredis_client = self._dashboard_head.aioredis_client
        receiver = Receiver()

        channel = receiver.channel(gcs_utils.LOG_FILE_CHANNEL)
        await aioredis_client.subscribe(channel)
        logger.info("Subscribed to %s", channel)

        async for _, payload in receiver.iter():
            try:
                batch = json.loads(ray._private.utils.decode(payload))
                batch["pid"] = str(batch["pid"])
                process_log_batch(batch)
            except Exception:
                logger.exception("Error receiving log from Redis.")
예제 #8
0
async def test_stopped(create_connection, server, loop):
    """A message published after ``Receiver.stop()`` is dropped with a warning.

    After the stop, ``get()`` must raise ``ChannelClosedError`` and
    ``wait_message()`` must return False.
    """
    sub = await create_connection(server.tcp_address, loop=loop)
    pub = await create_connection(server.tcp_address, loop=loop)

    mpsc = Receiver(loop=loop)
    await sub.execute_pubsub('subscribe', mpsc.channel('channel:1'))
    assert mpsc.is_active
    mpsc.stop()

    with logs('aioredis', 'DEBUG') as cm:
        await pub.execute('publish', 'channel:1', b'Hello')
        # Yield once, presumably so the subscriber connection reads the message.
        await asyncio.sleep(0, loop=loop)

    # Exactly one log line: the "message after stop" warning below.
    assert len(cm.output) == 1
    # Receiver must have 1 EndOfStream message
    warn_messaege = (
        "WARNING:aioredis:Pub/Sub listener message after stop: "
        "sender: <_Sender name:b'channel:1', is_pattern:False, receiver:"
        "<Receiver is_active:False, senders:1, qsize:0>>, data: b'Hello'"
    )
    assert cm.output == [warn_messaege]

    # assert (await mpsc.get()) is None
    with pytest.raises(ChannelClosedError):
        await mpsc.get()
    res = await mpsc.wait_message()
    assert res is False
예제 #9
0
async def redis_relay(websocket):
    """Forward every message published on ``marketupdates`` to *websocket*.

    Payloads are decoded from bytes with the default codec and sent as text.
    Runs until the receiver reports no more messages. The Redis connection
    is closed on exit, including when sending to the websocket fails.
    """
    conn = await aioredis.create_connection(("localhost", 6379))
    try:
        receiver = Receiver()
        # Await the SUBSCRIBE so a failure is raised here instead of being
        # left in an un-awaited future.
        await conn.execute_pubsub("subscribe", receiver.channel("marketupdates"))
        while await receiver.wait_message():
            *_, message = await receiver.get()
            await websocket.send(message.decode())
    finally:
        conn.close()
        await conn.wait_closed()
예제 #10
0
async def test_stopped(create_connection, server, loop):
    """A message published after ``Receiver.stop()`` is dropped with a warning.

    After the stop, ``get()`` must raise ``ChannelClosedError`` and
    ``wait_message()`` must return False.
    """
    sub = await create_connection(server.tcp_address, loop=loop)
    pub = await create_connection(server.tcp_address, loop=loop)

    mpsc = Receiver(loop=loop)
    await sub.execute_pubsub('subscribe', mpsc.channel('channel:1'))
    assert mpsc.is_active
    mpsc.stop()

    with logs('aioredis', 'DEBUG') as cm:
        await pub.execute('publish', 'channel:1', b'Hello')
        # Yield once, presumably so the subscriber connection reads the message.
        await asyncio.sleep(0, loop=loop)

    # Exactly one log line: the "message after stop" warning below.
    assert len(cm.output) == 1
    # Receiver must have 1 EndOfStream message
    warn_messaege = (
        "WARNING:aioredis:Pub/Sub listener message after stop: "
        "sender: <_Sender name:b'channel:1', is_pattern:False, receiver:"
        "<Receiver is_active:False, senders:1, qsize:0>>, data: b'Hello'")
    assert cm.output == [warn_messaege]

    # assert (await mpsc.get()) is None
    with pytest.raises(ChannelClosedError):
        await mpsc.get()
    res = await mpsc.wait_message()
    assert res is False
예제 #11
0
async def test_pubsub_receiver_iter(create_redis, server, loop):
    """``Receiver.iter()`` yields every (sender, message) pair until stopped.

    Two plain channels plus a pattern matching both are attached to one
    Receiver, so each publish arrives twice: raw via the channel sender and
    as ``(channel_name, payload)`` via the pattern sender.
    """
    sub = await create_redis(server.tcp_address, loop=loop)
    pub = await create_redis(server.tcp_address, loop=loop)

    mpsc = Receiver(loop=loop)

    async def coro(mpsc):
        # Drain the receiver until it is stopped and return what was seen.
        lst = []
        async for msg in mpsc.iter():
            lst.append(msg)
        return lst

    tsk = asyncio.ensure_future(coro(mpsc), loop=loop)
    snd1, = await sub.subscribe(mpsc.channel('chan:1'))
    snd2, = await sub.subscribe(mpsc.channel('chan:2'))
    snd3, = await sub.psubscribe(mpsc.pattern('chan:*'))

    # Channel + pattern subscriptions both match, so more than one receiver.
    subscribers = await pub.publish_json('chan:1', {'Hello': 'World'})
    assert subscribers > 1
    subscribers = await pub.publish_json('chan:2', ['message'])
    assert subscribers > 1
    # Stop the receiver on the next loop iteration so the iterator ends.
    # NOTE(review): nothing waits for the sub connection to have read the
    # published messages before stop runs -- a possible source of flakiness.
    loop.call_later(0, mpsc.stop)
    # await asyncio.sleep(0, loop=loop)
    assert await tsk == [
        (snd1, b'{"Hello": "World"}'),
        (snd3, (b'chan:1', b'{"Hello": "World"}')),
        (snd2, b'["message"]'),
        (snd3, (b'chan:2', b'["message"]')),
    ]
    assert not mpsc.is_active
예제 #12
0
    async def connect(
        self, *, conn_retries: int = 5, conn_retry_delay: int = 1, retry: int = 0
    ) -> "Broadcast":
        """Open the publish and subscribe Redis connections, retrying on failure.

        Args:
            conn_retries: maximum number of retries after the first attempt.
            conn_retry_delay: seconds to sleep between attempts.
            retry: number of retries already performed (kept for callers of
                the former recursive implementation).

        Returns:
            ``self``, once both connections and the Receiver are set up.

        Raises:
            ConnectionError, aioredis.RedisError, asyncio.TimeoutError: the
            last connection error, once ``conn_retries`` is exhausted.
        """
        while True:
            pub_conn = None
            try:
                pub_conn = await aioredis.create_redis(self.connection_url)
                sub_conn = await aioredis.create_redis(self.connection_url)
                receiver = Receiver()
            except (ConnectionError, aioredis.RedisError, asyncio.TimeoutError) as e:
                # Don't leak the publish connection when only the subscribe
                # connection failed.
                if pub_conn is not None:
                    pub_conn.close()
                    await pub_conn.wait_closed()
                if retry >= conn_retries:
                    logger.error("Connecting to Redis failed")
                    raise
                logger.warning(
                    "Redis connection error %s %s %s, %d retries remaining...",
                    self.connection_url,
                    e.__class__.__name__,
                    e,
                    conn_retries - retry,
                )
                await asyncio.sleep(conn_retry_delay)
                retry += 1
            else:
                # Publish the connections only once all of them exist.
                self._pub_conn = pub_conn
                self._sub_conn = sub_conn
                self._receiver = receiver
                if retry > 0:
                    logger.info("Redis connection successful")
                return self
예제 #13
0
 def __init__(self, bot) -> None:
     """Store *bot*, schedule ``init()`` and prepare the Redis receiver.

     ``handlers`` maps message names to bound handler methods.
     """
     self.bot: GearBot = bot
     loop = bot.loop
     loop.create_task(self.init())
     self.redis_link = None
     self.receiver = Receiver(loop=loop)
     self.handlers = {"guild_perm_request": self.guild_perm_request}
     # NOTE(review): if _receiver is a coroutine function, this only creates
     # the coroutine object; it is not scheduled here.
     self.task = self._receiver()
예제 #14
0
파일: app.py 프로젝트: Dermogod/ForcAD
async def before_server_start(_sanic, loop):
    """Subscribe to the scoreboard channels and start the relay task.

    One Receiver collects both ``scoreboard`` and ``stolen_flags``. It is
    handed to ``background_task`` through ``sio.start_background_task``.
    """
    receiver = Receiver(loop=loop)
    redis = await storage.get_async_redis_pool(loop)
    channels = [receiver.channel(name) for name in ('scoreboard', 'stolen_flags')]
    await redis.subscribe(*channels)
    sio.start_background_task(background_task, receiver)
예제 #15
0
 async def receive(self):
     """Print every message published on ``self.channel`` until the receiver stops.

     Uses a dedicated connection to localhost:6379, which is closed on exit.
     """
     print('Receive from redis')
     sub = await aioredis.create_connection(('localhost', 6379))
     try:
         receiver = Receiver()
         # Await the SUBSCRIBE so a failure is raised here instead of being
         # left in an un-awaited future.
         await sub.execute_pubsub('subscribe', receiver.channel(self.channel))
         while (await receiver.wait_message()):
             msg = await receiver.get()
             print("Got Message:", msg)
     finally:
         sub.close()
         await sub.wait_closed()
예제 #16
0
 async def run(self, loop, **kwargs):
     """Pass every JSON tick published on a ``TICK:*`` channel to handle_tick().

     *loop* and *kwargs* are accepted but not used.
     """
     aredis = await create_aredis()
     receiver = Receiver()
     await aredis.psubscribe(receiver.pattern("TICK:*"))
     logger.info("subscribed")
     # Pattern messages arrive as (channel_name, payload); parse the payload.
     async for _, pattern_msg in receiver.iter():
         tick = ujson.loads(pattern_msg[1])
         self.handle_tick(tick)
예제 #17
0
 async def subscribe(self):
     """Return the shared Receiver, subscribing it to ``self._channels`` on first use."""
     if self._subscribe:
         return self._subscribe
     redis = await self.redis
     receiver = Receiver(loop=self.loop)
     for name in self._channels:
         await redis.subscribe(receiver.channel(name))
     self._subscribe = receiver
     return self._subscribe
예제 #18
0
파일: redis.py 프로젝트: scorp249/ipapp
 async def _connect_subscr(self) -> None:
     """Create the subscriber connection and start the reader future.

     Runs while holding ``self.redis_connect_subscr`` (presumably an asyncio
     lock). Subscribes a fresh Receiver to ``self.cfg.channel`` and schedules
     ``self._reader`` on it.
     """
     async with self.redis_connect_subscr:
         self.redis_subscr = await create_redis(self.cfg.url, encoding=self.encoding)
         self.mpsc = Receiver(loop=app.loop)
         subscription = self.mpsc.channel(self.cfg.channel)
         await self.redis_subscr.subscribe(subscription)
         reader = self._reader(self.mpsc)
         self._reader_fut = asyncio.ensure_future(reader)
예제 #19
0
async def subscribe_to_channel(loop, redis_pool):
    """Subscribe a pooled connection to ``websocket_channel``.

    Returns the Receiver that gets the channel's messages. The acquired
    connection is not released here.
    """
    receiver = Receiver(loop=loop)

    logger.info("Aquiring redis connection from pool")
    conn = await redis_pool.acquire()

    logger.info("Subscribing to redis channel: {}".format(websocket_channel))
    subscription = receiver.channel(websocket_channel)
    await conn.execute_pubsub("subscribe", subscription)
    return receiver
예제 #20
0
async def redis_listener():
    """Relay messages from two Redis notification channels to every user.

    Each message is emitted to each sid in ``users`` as ``{'data': str(msg)}``
    on the ``notification:message`` Socket.IO event. The Redis connection is
    closed when the loop ends or an error escapes.
    """
    redis = await aioredis.create_redis('redis://redishost:6379')
    try:
        receiver = Receiver()
        await redis.subscribe(receiver.channel('sockets:notification:message'),
                              receiver.channel('sockets:notification:message2'))

        while await receiver.wait_message():
            _sender, msg = await receiver.get()
            # Iterate over a snapshot: `users` may change while emit() is
            # awaited, and changing a dict/set mid-iteration raises RuntimeError.
            for sid in list(users):
                await sio.emit('notification:message', {'data': str(msg)}, room=sid)
    finally:
        redis.close()
        await redis.wait_closed()
예제 #21
0
    async def connect(self) -> None:
        """Open pub/sub connections, the message queue and the reader task.

        A second call after a successful connect only logs a warning. If the
        subscribe connection cannot be opened, the publish connection is
        closed and no attribute is set, so a later call can retry.
        """
        if self._pub_conn or self._sub_conn or self._msg_queue:
            logger.warning("connections are already setup but connect called again; not doing anything")
            return

        pub_conn = await aioredis.create_redis(self.conn_url)
        try:
            sub_conn = await aioredis.create_redis(self.conn_url)
        except BaseException:
            # Otherwise _pub_conn would stay set (blocking every retry via the
            # guard above) and the connection would leak.
            pub_conn.close()
            await pub_conn.wait_closed()
            raise
        self._pub_conn = pub_conn
        self._sub_conn = sub_conn
        self._msg_queue = asyncio.Queue()  # must be created here, to get proper event loop
        self._mpsc = Receiver()
        self._reader_task = asyncio.create_task(self._reader())
예제 #22
0
    async def initialize(self):
        """Connect to Redis, subscribe to ``self.inbound`` and start fetcher()."""
        pool = await aioredis.create_redis_pool(
            ("localhost", 6379), encoding="utf-8", maxsize=2)
        self.conn = pool
        self.receiver = Receiver(loop=self.loop)
        inbound_channel = self.receiver.channel(self.inbound)
        await self.conn.subscribe(inbound_channel)
        self.task = self.loop.create_task(self.fetcher())

        # FIXME - proper logging
        print(
            f'Redis connection established, listening on {self.inbound}, sending on {self.outbound}'
        )
예제 #23
0
파일: event.py 프로젝트: Tylerlog/hagworm
    async def _event_listener(self, channel):
        """Forward messages on *channel* to _event_assigner(), reconnecting each round.

        On every AsyncCirculator iteration a client is taken from the pool,
        a fresh Receiver is subscribed to *channel*, and each
        ``(sender, message)`` pair is passed to ``self._event_assigner``.
        """
        async for _ in AsyncCirculator():

            async with self._redis_pool.get_client() as cache:

                receiver = Receiver()

                await cache.subscribe(receiver.channel(channel))

                # Use a distinct loop name: re-binding `channel` here made the
                # next round subscribe to the last sender object instead of
                # the requested channel name.
                async for sender, message in receiver.iter():
                    await self._event_assigner(sender, message)
예제 #24
0
    async def _update_jobs(self):
        """Keep ``DataSource.jobs`` in sync with GCS.

        First subscribes to the job pub/sub pattern (presumably so updates
        published meanwhile are buffered). Then loads the full job table from
        GCS, retrying until it succeeds. Finally applies incremental updates,
        but only for jobs submitted from the dashboard.
        """
        # Subscribe job channel.
        aioredis_client = self._dashboard_head.aioredis_client
        receiver = Receiver()

        key = f"{job_consts.JOB_CHANNEL}:*"
        await aioredis_client.psubscribe(receiver.pattern(key))
        logger.info("Subscribed to %s", key)

        # Load the full job table, retrying until GCS answers successfully.
        loaded = False
        while not loaded:
            try:
                logger.info("Getting all job info from GCS.")
                request = gcs_service_pb2.GetAllJobInfoRequest()
                reply = await self._gcs_job_info_stub.GetAllJobInfo(
                    request, timeout=5)
                if reply.status.code != 0:
                    raise Exception(
                        f"Failed to GetAllJobInfo: {reply.status.message}")
                converted = (job_table_data_to_dict(item)
                             for item in reply.job_info_list)
                jobs = {entry["jobId"]: entry for entry in converted}
                DataSource.jobs.reset(jobs)
                logger.info("Received %d job info from GCS.", len(jobs))
                loaded = True
            except Exception:
                logger.exception("Error Getting all job info from GCS.")
                await asyncio.sleep(
                    job_consts.RETRY_GET_ALL_JOB_INFO_INTERVAL_SECONDS)

        # Apply incremental job updates from the channel.
        async for _, msg in receiver.iter():
            try:
                # Pattern subscriptions deliver (channel, payload) pairs.
                _, payload = msg
                pubsub_message = ray.gcs_utils.PubSubMessage.FromString(payload)
                job_data = ray.gcs_utils.JobTableData.FromString(
                    pubsub_message.data)
                job_id = ray._raylet.JobID(job_data.job_id)
                if not job_id.is_submitted_from_dashboard():
                    logger.info(
                        "Ignore job %s which is not submitted from dashboard.",
                        job_id.hex())
                    continue
                job_table_data = job_table_data_to_dict(job_data)
                DataSource.jobs[job_table_data["jobId"]] = job_table_data
            except Exception:
                logger.exception("Error receiving job info.")
예제 #25
0
async def main():
    """Start a reader on ``channel:1`` and print a heartbeat every 10 seconds.

    When an exception ends the loop, it is printed. On that exit, and on
    cancellation, the receiver is stopped and the Redis wrapper is
    disconnected before returning.
    """
    await redis.connect()
    mpsc = Receiver(loop=asyncio.get_event_loop())
    # Keep a reference: the event loop holds only weak references to tasks,
    # so an unreferenced task may be garbage-collected mid-run.
    reader_task = asyncio.ensure_future(reader(mpsc))
    await redis._redis.subscribe(mpsc.channel('channel:1'))
    try:
        while True:
            await asyncio.sleep(10)
            print('hearbeat', flush=True)
    except Exception as ex:
        print(ex, flush=True)
    finally:
        # Clean up once and leave. Before, the loop kept heartbeating on a
        # stopped receiver and would disconnect again on the next error.
        mpsc.stop()
        await redis.disconnect()
예제 #26
0
    async def subscribe_chat(self):
        """Subscribe to CHAT_CHANNEL on a dedicated connection.

        Yields ``(channel, get)`` once; ``await get()`` returns the payload of
        the next chat message. When resumed, the generator unsubscribes. In
        every case (including an exception thrown in at the yield), the
        dedicated connection is closed.
        """
        redis_for_pubsub = await aioredis.create_redis(self.redis.address, db=self.redis.db)
        try:
            receiver = Receiver()

            channel = receiver.channel(CHAT_CHANNEL)
            await redis_for_pubsub.subscribe(channel)

            async def _get():
                # Messages are (sender, payload); return only the payload.
                return (await receiver.get())[1]

            yield channel, _get

            await redis_for_pubsub.unsubscribe(CHAT_CHANNEL)
        finally:
            # Previously the connection was never closed, and the unsubscribe
            # was skipped whenever the consumer raised.
            redis_for_pubsub.close()
            await redis_for_pubsub.wait_closed()
예제 #27
0
    async def _update_error_info(self):
        """Index Ray error messages by ip and pid, from GCS or from Redis.

        If the dashboard head has a ``gcs_error_subscriber``, it is polled
        forever. Otherwise the Redis error pattern is subscribed. Either way,
        each error whose text contains ``(pid=<n>, ip=<addr>)`` is appended to
        ``DataSource.ip_and_pid_to_errors[ip][pid]``.
        """

        def process_error(error_data):
            # Strip ANSI color codes of the form ESC[<n>m before matching.
            text = re.sub(r"\x1b\[\d+m", "", error_data.error_message)
            found = re.search(r"\(pid=(\d+), ip=(.*?)\)", text)
            if not found:
                return
            pid, ip = found.groups()
            ip_errors = dict(DataSource.ip_and_pid_to_errors.get(ip, {}))
            ip_errors[pid] = [
                *ip_errors.get(pid, []),
                {
                    "message": text,
                    "timestamp": error_data.timestamp,
                    "type": error_data.type,
                },
            ]
            DataSource.ip_and_pid_to_errors[ip] = ip_errors
            logger.info(f"Received error entry for {ip} {pid}")

        if not self._dashboard_head.gcs_error_subscriber:
            aioredis_client = self._dashboard_head.aioredis_client
            receiver = Receiver()

            key = gcs_utils.RAY_ERROR_PUBSUB_PATTERN
            await aioredis_client.psubscribe(receiver.pattern(key))
            logger.info("Subscribed to %s", key)

            async for _, msg in receiver.iter():
                try:
                    # Pattern subscriptions deliver (channel, payload) pairs.
                    _, data = msg
                    envelope = gcs_utils.PubSubMessage.FromString(data)
                    process_error(gcs_utils.ErrorTableData.FromString(envelope.data))
                except Exception:
                    logger.exception("Error receiving error info from Redis.")
        else:
            while True:
                try:
                    _, error_data = await self._dashboard_head.gcs_error_subscriber.poll()
                    if error_data is not None:
                        process_error(error_data)
                except Exception:
                    logger.exception("Error receiving error info from GCS.")
예제 #28
0
async def test_decode_message_error(loop):
    """``get(decoder=json.loads)`` on raw bytes is expected to raise TypeError.

    Checked for both a plain channel and a pattern channel.
    NOTE(review): json.loads accepts bytes since Python 3.6, so this only
    holds on older interpreters; presumably the real test carries a
    version-based skip marker -- confirm.
    """
    receiver = Receiver(loop)

    plain = receiver.channel('channel:1')
    plain.put_nowait(b'{"hello": "world"}')
    with pytest.raises(TypeError):
        assert (await receiver.get(decoder=json.loads)) == (mock.ANY, {'hello': 'world'})

    pattern = receiver.pattern('*')
    pattern.put_nowait((b'channel', b'{"hello": "world"}'))
    with pytest.raises(TypeError):
        assert (await receiver.get(decoder=json.loads)) == (
            mock.ANY, b'channel', {'hello': 'world'})
예제 #29
0
async def test_decode_message_error():
    """``get(decoder=json.loads)`` on raw bytes is expected to raise TypeError.

    Checked for both a plain channel and a pattern channel.
    NOTE(review): json.loads accepts bytes since Python 3.6, so this only
    holds on older interpreters; presumably the real test carries a
    version-based skip marker -- confirm.
    """
    receiver = Receiver()

    plain = receiver.channel("channel:1")
    plain.put_nowait(b'{"hello": "world"}')
    with pytest.raises(TypeError):
        assert (await receiver.get(decoder=json.loads)) == (mock.ANY, {"hello": "world"})

    pattern = receiver.pattern("*")
    pattern.put_nowait((b"channel", b'{"hello": "world"}'))
    with pytest.raises(TypeError):
        assert (await receiver.get(decoder=json.loads)) == (
            mock.ANY, b"channel", {"hello": "world"})
예제 #30
0
async def test_decode_message_error(loop):
    """``get(decoder=json.loads)`` on raw bytes is expected to raise TypeError.

    Checked for both a plain channel and a pattern channel.
    NOTE(review): json.loads accepts bytes since Python 3.6, so this only
    holds on older interpreters; presumably the real test carries a
    version-based skip marker -- confirm.
    """
    receiver = Receiver(loop)

    plain = receiver.channel('channel:1')
    plain.put_nowait(b'{"hello": "world"}')
    with pytest.raises(TypeError):
        assert (await receiver.get(decoder=json.loads)) == (mock.ANY, {'hello': 'world'})

    pattern = receiver.pattern('*')
    pattern.put_nowait((b'channel', b'{"hello": "world"}'))
    with pytest.raises(TypeError):
        assert (await receiver.get(decoder=json.loads)) == (
            mock.ANY, b'channel', {'hello': 'world'})
예제 #31
0
    async def run(self):
        """Subscribe to reporter channels and store each node's stats by ip.

        Every payload on a ``REPORTER_PREFIX*`` channel is decoded as JSON and
        stored in ``DataSource.node_physical_stats`` under its ``"ip"``.
        Per-message failures are logged and do not stop the loop.
        """
        aioredis_client = self._dashboard_head.aioredis_client
        receiver = Receiver()

        reporter_key = "{}*".format(reporter_consts.REPORTER_PREFIX)
        await aioredis_client.psubscribe(receiver.pattern(reporter_key))
        logger.info("Subscribed to {}".format(reporter_key))

        async for _, msg in receiver.iter():
            try:
                # Pattern subscriptions deliver (channel, payload) pairs.
                _, raw = msg
                stats = json.loads(ray.utils.decode(raw))
                DataSource.node_physical_stats[stats["ip"]] = stats
            except Exception as ex:
                logger.exception(ex)
예제 #32
0
async def test_pubsub_receiver_call_stop_with_empty_queue(
        create_redis, server, loop):
    """Stopping an idle Receiver ends ``iter()`` without yielding anything.

    ``stop`` is scheduled after 0.5 s. The loop body must never run, and the
    iteration must finish within 1.5 s of starting.
    """
    sub = await create_redis(server.tcp_address, loop=loop)

    mpsc = Receiver(loop=loop)

    # FIXME: currently at least one subscriber is needed
    snd1, = await sub.subscribe(mpsc.channel('chan:1'))

    now = loop.time()
    loop.call_later(.5, mpsc.stop)
    async for i in mpsc.iter():  # noqa (flake8 bug with async for)
        assert False, "StopAsyncIteration not raised"
    dt = loop.time() - now
    assert dt <= 1.5
    assert not mpsc.is_active
예제 #33
0
async def test_pubsub_receiver_call_stop_with_empty_queue(
        create_redis, server, loop):
    """Stopping an idle Receiver ends ``iter()`` without yielding anything.

    ``stop`` is scheduled after 0.5 s. The loop body must never run, and the
    iteration must finish within 1.5 s of starting.
    """
    sub = await create_redis(server.tcp_address, loop=loop)

    mpsc = Receiver(loop=loop)

    # FIXME: currently at least one subscriber is needed
    snd1, = await sub.subscribe(mpsc.channel('chan:1'))

    now = loop.time()
    loop.call_later(.5, mpsc.stop)
    async for i in mpsc.iter():  # noqa (flake8 bug with async for)
        assert False, "StopAsyncIteration not raised"
    dt = loop.time() - now
    assert dt <= 1.5
    assert not mpsc.is_active
예제 #34
0
 async def listen_redis(self, websocket, channels):
     """Forward messages from the Redis *channels* to *websocket* as bytes.

     A message is only sent while the websocket state is CONNECTED. On an
     error, the traceback is printed and ``close_connections()`` is awaited.
     On cancellation, connections are closed and the cancellation is
     re-raised. *channels* is not modified.
     """
     # Index instead of pop(0) so the caller's list is left intact; an empty
     # list still raises IndexError, as before.
     first_channel, other_channels = channels[0], channels[1:]
     try:
         mpsc = Receiver(loop=asyncio.get_event_loop())
         await self.sub.subscribe(
             mpsc.channel(first_channel),
             *(mpsc.channel(name) for name in other_channels))
         async for _, msg in mpsc.iter():
             if websocket.client_state == WebSocketState.CONNECTED:
                 await websocket.send_bytes(msg)
     except asyncio.CancelledError:
         # Clean up, but let the cancellation propagate. The former bare
         # `except:` swallowed it, so the task never actually ended as cancelled.
         await self.close_connections()
         raise
     except Exception:
         import traceback
         traceback.print_exc()
         await self.close_connections()
     finally:
         logger.info('Connection closed')
예제 #35
0
async def test_wait_message(create_connection, server, loop):
    """``wait_message()`` stays pending until a message arrives, then yields True."""
    sub = await create_connection(server.tcp_address, loop=loop)
    pub = await create_connection(server.tcp_address, loop=loop)

    mpsc = Receiver(loop=loop)
    await sub.execute_pubsub('subscribe', mpsc.channel('channel:1'))
    fut = asyncio.ensure_future(mpsc.wait_message(), loop=loop)
    # Still pending both before and after giving the loop one iteration.
    assert not fut.done()
    await asyncio.sleep(0, loop=loop)
    assert not fut.done()

    await pub.execute('publish', 'channel:1', 'hello')
    # The two yields below are each tied to one step; see the inline notes.
    await asyncio.sleep(0, loop=loop)  # read in connection
    await asyncio.sleep(0, loop=loop)  # call Future.set_result
    assert fut.done()
    res = await fut
    assert res is True
예제 #36
0
async def test_decode_message(loop):
    """``Receiver.get`` applies ``encoding`` and/or ``decoder`` to channel payloads.

    The sender half of each result must be a ``_Sender``.
    """
    receiver = Receiver(loop)
    channel = receiver.channel('channel:1')

    cases = [
        (b'Some data', {'encoding': 'utf-8'}, 'Some data'),
        ('{"hello": "world"}', {'decoder': json.loads}, {'hello': 'world'}),
        (b'{"hello": "world"}', {'encoding': 'utf-8', 'decoder': json.loads},
         {'hello': 'world'}),
    ]
    for payload, get_kwargs, expected in cases:
        channel.put_nowait(payload)
        res = await receiver.get(**get_kwargs)
        assert isinstance(res[0], _Sender)
        assert res[1] == expected
예제 #37
0
async def test_decode_message_for_pattern(loop):
    """For pattern channels, ``encoding``/``decoder`` apply only to the payload.

    The matched channel name stays raw bytes; the sender must be a ``_Sender``.
    """
    receiver = Receiver(loop)
    pattern = receiver.pattern('*')

    cases = [
        ((b'channel', b'Some data'), {'encoding': 'utf-8'},
         (b'channel', 'Some data')),
        ((b'channel', '{"hello": "world"}'), {'decoder': json.loads},
         (b'channel', {'hello': 'world'})),
        ((b'channel', b'{"hello": "world"}'),
         {'encoding': 'utf-8', 'decoder': json.loads},
         (b'channel', {'hello': 'world'})),
    ]
    for payload, get_kwargs, expected in cases:
        pattern.put_nowait(payload)
        res = await receiver.get(**get_kwargs)
        assert isinstance(res[0], _Sender)
        assert res[1] == expected
예제 #38
0
async def test_unsubscribe(create_connection, server, loop):
    """Unsubscribing keeps the Receiver active until its last channel goes.

    After the last channel is unsubscribed, ``is_active`` is False and a
    pending ``get()`` resolves to None.
    """
    sub = await create_connection(server.tcp_address, loop=loop)
    pub = await create_connection(server.tcp_address, loop=loop)

    mpsc = Receiver(loop=loop)
    await sub.execute_pubsub('subscribe',
                             mpsc.channel('channel:1'),
                             mpsc.channel('channel:3'))
    res = await pub.execute("publish", "channel:3", "Hello world")
    assert res == 1
    res = await pub.execute("publish", "channel:1", "Hello world")
    assert res == 1
    assert mpsc.is_active

    # Messages are received in publish order: channel:3 first.
    assert (await mpsc.wait_message()) is True
    ch, msg = await mpsc.get()
    assert ch.name == b'channel:3'
    assert not ch.is_pattern
    assert msg == b"Hello world"

    assert (await mpsc.wait_message()) is True
    ch, msg = await mpsc.get()
    assert ch.name == b'channel:1'
    assert not ch.is_pattern
    assert msg == b"Hello world"

    # channel:3 is still subscribed, so the receiver stays active.
    await sub.execute_pubsub('unsubscribe', 'channel:1')
    assert mpsc.is_active

    res = await pub.execute("publish", "channel:3", "message")
    assert res == 1
    assert (await mpsc.wait_message()) is True
    ch, msg = await mpsc.get()
    assert ch.name == b'channel:3'
    assert not ch.is_pattern
    assert msg == b"message"

    # A get() pending when the last channel is unsubscribed resolves to None.
    waiter = asyncio.ensure_future(mpsc.get(), loop=loop)
    await sub.execute_pubsub('unsubscribe', 'channel:3')
    assert not mpsc.is_active
    assert await waiter is None
예제 #39
0
async def test_subscriptions(create_connection, server, loop):
    """Messages on two subscribed channels reach one Receiver in publish order."""
    sub = await create_connection(server.tcp_address, loop=loop)
    pub = await create_connection(server.tcp_address, loop=loop)

    receiver = Receiver(loop=loop)
    await sub.execute_pubsub('subscribe',
                             receiver.channel('channel:1'),
                             receiver.channel('channel:3'))
    publish_order = ["channel:3", "channel:1"]
    for name in publish_order:
        assert await pub.execute("publish", name, "Hello world") == 1
    assert receiver.is_active

    for name in publish_order:
        ch, msg = await receiver.get()
        assert ch.name == name.encode()
        assert not ch.is_pattern
        assert msg == b"Hello world"
예제 #40
0
def test_listener_pattern(loop):
    """``Receiver.pattern`` caches channels by name and tracks activity.

    Repeated calls return the same channel object. Closing the only pattern
    deactivates the receiver, and a later call returns a brand-new channel.
    """
    receiver = Receiver(loop=loop)
    assert not receiver.is_active

    first = receiver.pattern("*")
    assert isinstance(first, AbcChannel)
    assert receiver.is_active

    again = receiver.pattern('*')
    assert first is again
    assert first.name == again.name
    assert first.is_pattern == again.is_pattern
    assert receiver.is_active

    # Close the cached channel; requesting the pattern again yields a new object.
    first.close()
    assert not first.is_active

    assert not receiver.is_active
    fresh = receiver.pattern("*")
    assert fresh is not first

    assert dict(receiver.channels) == {}
    assert dict(receiver.patterns) == {b'*': fresh}
예제 #41
0
async def test_pubsub_receiver_stop_on_disconnect(create_redis, server, loop):
    """Killing the subscriber's connection ends ``Receiver.iter()``.

    Messages published before the kill are all delivered (channel and
    pattern copies). After ``CLIENT KILL`` the reader task must finish within
    one second and enqueue its EOF marker.
    """
    pub = await create_redis(server.tcp_address, loop=loop)
    sub = await create_redis(server.tcp_address, loop=loop)
    sub_name = 'sub-{:X}'.format(id(sub))
    await sub.client_setname(sub_name)
    # Locate the subscriber in CLIENT LIST. Before, this relied on the for-loop
    # variable leaking, which raised NameError on an empty list.
    clients = await pub.client_list()
    sub_info = next((info for info in clients if info.name == sub_name), None)
    assert sub_info is not None, "subscriber connection not found in CLIENT LIST"

    mpsc = Receiver(loop=loop)
    await sub.subscribe(mpsc.channel('channel:1'))
    await sub.subscribe(mpsc.channel('channel:2'))
    await sub.psubscribe(mpsc.pattern('channel:*'))

    q = asyncio.Queue(loop=loop)
    EOF = object()

    async def reader():
        # Relay (sender name, decoded message) pairs, then signal the end.
        async for ch, msg in mpsc.iter(encoding='utf-8'):
            await q.put((ch.name, msg))
        await q.put(EOF)

    tsk = asyncio.ensure_future(reader(), loop=loop)
    await pub.publish_json('channel:1', ['hello'])
    await pub.publish_json('channel:2', ['hello'])
    # receive all messages
    assert await q.get() == (b'channel:1', '["hello"]')
    assert await q.get() == (b'channel:*', (b'channel:1', '["hello"]'))
    assert await q.get() == (b'channel:2', '["hello"]')
    assert await q.get() == (b'channel:*', (b'channel:2', '["hello"]'))

    # XXX: need to implement `client kill`
    assert await pub.execute('client', 'kill', sub_info.addr) in (b'OK', 1)
    await asyncio.wait_for(tsk, timeout=1, loop=loop)
    assert await q.get() is EOF