async def test_pubsub_receiver_iter(create_redis, server, loop):
    """Channel and pattern messages are all delivered, in order, through Receiver.iter().

    Legacy variant of this test that passes an explicit ``loop`` to every
    aioredis/asyncio call.
    """
    sub = await create_redis(server.tcp_address, loop=loop)
    pub = await create_redis(server.tcp_address, loop=loop)
    mpsc = Receiver(loop=loop)

    async def coro(mpsc):
        # Drain the receiver until iter() stops; collect every (sender, msg) item.
        lst = []
        async for msg in mpsc.iter():
            lst.append(msg)
        return lst

    tsk = asyncio.ensure_future(coro(mpsc), loop=loop)
    snd1, = await sub.subscribe(mpsc.channel('chan:1'))
    snd2, = await sub.subscribe(mpsc.channel('chan:2'))
    snd3, = await sub.psubscribe(mpsc.pattern('chan:*'))

    # Each publish is seen by the plain channel AND by the 'chan:*' pattern,
    # hence more than one subscriber.
    subscribers = await pub.publish_json('chan:1', {'Hello': 'World'})
    assert subscribers > 1
    subscribers = await pub.publish_json('chan:2', ['message'])
    assert subscribers > 1

    # Stop the receiver on a later loop iteration so the reader task finishes.
    loop.call_later(0, mpsc.stop)
    # NOTE(review): disabled extra yield left over from earlier; delete or restore.
    # await asyncio.sleep(0, loop=loop)
    # Pattern messages carry (channel, payload); plain ones carry just the payload.
    assert await tsk == [
        (snd1, b'{"Hello": "World"}'),
        (snd3, (b'chan:1', b'{"Hello": "World"}')),
        (snd2, b'["message"]'),
        (snd3, (b'chan:2', b'["message"]')),
    ]
    assert not mpsc.is_active
class Messager:
    """Request/reply helper built on Redis pub/sub.

    Requests are published as JSON on ``outbound``, each tagged with a fresh
    ``uid``. Replies arriving on ``inbound`` with a matching ``uid`` are handed
    back to the caller waiting in :meth:`get_reply`.
    """

    def __init__(self, inbound, outbound, loop):
        self.loop = loop
        self.conn = None
        self.inbound = inbound  # channel replies are received on
        self.outbound = outbound  # channel requests are published to
        self.receiver = None
        self.task = None
        # uid -> reply message that has arrived but not yet been collected
        self.replies = dict()
        # uids of requests that are still waiting for a reply
        self.expected = set()

    async def initialize(self):
        """Connect to Redis, subscribe to ``inbound`` and start the fetcher task."""
        self.conn = await aioredis.create_redis_pool(("localhost", 6379), encoding="utf-8", maxsize=2)
        self.receiver = Receiver(loop=self.loop)
        await self.conn.subscribe(self.receiver.channel(self.inbound))
        self.task = self.loop.create_task(self.fetcher())
        print(
            f'Redis connection established, listening on {self.inbound}, sending on {self.outbound}'
        )  # FIXME - proper logging

    async def terminate(self):
        """Unsubscribe, stop the fetcher and receiver, and close the Redis pool."""
        # unsubscribe() returns an awaitable. Awaiting it lets the command finish
        # (and surfaces errors) before the pool is closed below.
        await self.conn.unsubscribe(self.inbound)
        self.task.cancel()
        self.receiver.stop()
        self.conn.close()
        await self.conn.wait_closed()

    async def fetcher(self):
        """Route incoming replies to waiting requests by ``uid``; report anything unexpected."""
        async for sender, message in self.receiver.iter(encoding='utf-8', decoder=json.loads):
            channel = sender.name.decode()
            if channel != self.inbound:
                continue
            # A message without a "uid" key used to raise KeyError here and
            # silently kill this background task; treat it as unexpected instead.
            uid = message.get("uid") if isinstance(message, dict) else None
            if uid not in self.expected:
                print("Unexpected message!")
                print(message)
                continue
            self.expected.remove(uid)
            self.replies[uid] = message

    async def get_reply(self, data):
        """Publish *data* and return the ``"reply"`` field of the matching answer.

        Raises:
            Redisception: if no reply arrives within 10 seconds.
        """
        try:
            reply = await asyncio.wait_for(self._get_reply(data), 10)
        except asyncio.TimeoutError as err:
            # wait_for raises asyncio.TimeoutError. That is the builtin
            # TimeoutError only from Python 3.11 on, so catching the builtin
            # missed timeouts on older interpreters.
            raise Redisception("No reply received from the bot!") from err
        return reply["reply"]

    async def _get_reply(self, data):
        """Tag *data* with a new uid, publish it, and poll until the matching reply arrives.

        Mutates *data* by setting its ``"uid"`` key.
        """
        uid = str(uuid.uuid4())
        self.expected.add(uid)
        data["uid"] = uid
        try:
            await self.conn.publish_json(self.outbound, data)
            while uid not in self.replies:
                await asyncio.sleep(0.1)
            return self.replies.pop(uid)
        finally:
            # On timeout/cancellation, forget this request so `expected` does not
            # grow forever. A late reply is then reported as unexpected instead
            # of being stored and never collected.
            self.expected.discard(uid)
            self.replies.pop(uid, None)
async def test_pubsub_receiver_iter(create_redis, server, event_loop):
    """Channel and pattern messages are all delivered, in order, through Receiver.iter()."""
    sub = await create_redis(server.tcp_address)
    pub = await create_redis(server.tcp_address)
    mpsc = Receiver()

    async def coro(mpsc):
        # Drain the receiver until iter() stops; collect every (sender, msg) item.
        lst = []
        async for msg in mpsc.iter():
            lst.append(msg)
        return lst

    tsk = asyncio.ensure_future(coro(mpsc))
    (snd1, ) = await sub.subscribe(mpsc.channel("chan:1"))
    (snd2, ) = await sub.subscribe(mpsc.channel("chan:2"))
    (snd3, ) = await sub.psubscribe(mpsc.pattern("chan:*"))

    # Each publish reaches the plain channel AND the "chan:*" pattern.
    subscribers = await pub.publish_json("chan:1", {"Hello": "World"})
    assert subscribers > 1
    subscribers = await pub.publish_json("chan:2", ["message"])
    assert subscribers > 1

    # Schedule stop() so that iter() ends and the reader task returns its list.
    event_loop.call_later(0, mpsc.stop)
    await asyncio.sleep(0.01)
    # Pattern messages carry (channel, payload); plain ones carry just the payload.
    assert await tsk == [
        (snd1, b'{"Hello": "World"}'),
        (snd3, (b"chan:1", b'{"Hello": "World"}')),
        (snd2, b'["message"]'),
        (snd3, (b"chan:2", b'["message"]')),
    ]
    assert not mpsc.is_active
async def test_stopped(create_connection, server, caplog):
    """After Receiver.stop(), a late publish is logged as one warning and reads are refused."""
    sub = await create_connection(server.tcp_address)
    pub = await create_connection(server.tcp_address)
    mpsc = Receiver()
    await sub.execute_pubsub("subscribe", mpsc.channel("channel:1"))
    assert mpsc.is_active
    mpsc.stop()

    # Capture only the records produced by the publish below.
    caplog.clear()
    with caplog.at_level("DEBUG", "aioredis"):
        await pub.execute("publish", "channel:1", b"Hello")
        await asyncio.sleep(0)

    # Exactly one record: the warning about a message delivered after stop().
    assert len(caplog.record_tuples) == 1
    message = (
        "Pub/Sub listener message after stop: "
        "sender: <_Sender name:b'channel:1', is_pattern:False, receiver:"
        "<Receiver is_active:False, senders:1, qsize:0>>, data: b'Hello'")
    assert caplog.record_tuples == [
        ("aioredis", logging.WARNING, message),
    ]

    # A stopped receiver refuses get() and reports that no message is pending.
    # assert (await mpsc.get()) is None
    with pytest.raises(ChannelClosedError):
        await mpsc.get()
    res = await mpsc.wait_message()
    assert res is False
def __init__(self, bot):
    """Set up Redis link state, message-dispatch tables and dashboard bookkeeping."""
    super().__init__(bot)
    # Only scheduled here; init() does not start until the event loop regains control.
    bot.loop.create_task(self.init())
    self.redis_link: aioredis.Redis = None
    self.receiver = Receiver(loop=bot.loop)

    # Top-level message type -> handler coroutine.
    self.handlers = {
        "question": self.question,
        "update": self.update,
        "user_guilds": self.user_guilds,
        "user_guilds_end": self.user_guilds_end,
        "guild_info_watch": self.guild_info_watch,
        "guild_info_watch_end": self.guild_info_watch_end,
    }
    # Question sub-type -> handler coroutine.
    self.question_handlers = {
        "heartbeat": self.still_spinning,
        "user_info": self.user_info_request,
        "get_guild_settings": self.get_guild_settings,
        "save_guild_settings": self.save_guild_settings,
        "replace_guild_settings": self.replace_guild_settings,
        "setup_mute": self.setup_mute,
        "cleanup_mute": self.cleanup_mute,
        "cache_info": self.cache_info,
        "guild_user_perms": self.guild_user_perms,
    }

    # [time of last received heartbeat, current attempt number, times the owner was notified]
    self.last_dash_heartbeat = [time.time(), 0, 0]
    self.last_update = datetime.now()
    self.to_log = {}
    self.update_message = None

    translation_defaults = dict(SOURCE="SITE", CHANNEL=0, KEY="", LOGIN="", WEBROOT="")
    translations = Configuration.get_master_var("TRANSLATIONS", translation_defaults)
    if translations["SOURCE"] == 'CROWDIN':
        self.handlers["crowdin_webhook"] = self.crowdin_webhook

    # Coroutine object only. NOTE(review): presumably scheduled elsewhere (e.g. in init()) — confirm.
    self.task = self._receiver()
async def _update_error_info(self):
    """Subscribe to Ray error pub/sub messages and index them by node IP and worker PID.

    Each error whose text contains ``(pid=<pid>, ip=<ip>)`` is appended to
    ``DataSource.ip_and_pid_to_errors[ip][pid]``. Errors without that marker are
    ignored. A message that fails to decode is logged and skipped.
    """
    aioredis_client = self._dashboard_head.aioredis_client
    receiver = Receiver()
    key = ray.gcs_utils.RAY_ERROR_PUBSUB_PATTERN
    pattern = receiver.pattern(key)
    await aioredis_client.psubscribe(pattern)
    logger.info("Subscribed to %s", key)
    async for _sender, msg in receiver.iter():
        try:
            # Pattern subscriptions deliver (channel, payload) pairs.
            _, data = msg
            pubsub_msg = ray.gcs_utils.PubSubMessage.FromString(data)
            error_data = ray.gcs_utils.ErrorTableData.FromString(pubsub_msg.data)
            # Strip ANSI colour escape sequences before matching.
            message = re.sub(r"\x1b\[\d+m", "", error_data.error_message)
            match = re.search(r"\(pid=(\d+), ip=(.*?)\)", message)
            if match:
                pid = match.group(1)
                ip = match.group(2)
                # Build new containers rather than mutating the ones already
                # stored in DataSource, so the change is published only through
                # the final assignment. NOTE(review): assumed to matter for
                # DataSource change notifications — confirm.
                errs_for_ip = dict(DataSource.ip_and_pid_to_errors.get(ip, {}))
                pid_errors = list(errs_for_ip.get(pid, []))
                pid_errors.append({
                    "message": message,
                    "timestamp": error_data.timestamp,
                    "type": error_data.type
                })
                errs_for_ip[pid] = pid_errors
                DataSource.ip_and_pid_to_errors[ip] = errs_for_ip
                logger.info("Received error entry for %s %s", ip, pid)
        except Exception:
            logger.exception("Error receiving error info.")
async def _update_log_info(self):
    """Subscribe to the log-file channel and cache recent log lines per node IP and PID."""

    def process_log_batch(log_batch):
        # Merge one batch of lines into DataSource.ip_and_pid_to_logs[ip][pid].
        node_ip = log_batch["ip"]
        worker_pid = str(log_batch["pid"])
        if worker_pid != "autoscaler":
            # Work on copies; the final assignment publishes the update.
            ip_logs = dict(DataSource.ip_and_pid_to_logs.get(node_ip, {}))
            pid_logs = list(ip_logs.get(worker_pid, []))
            pid_logs.extend(log_batch["lines"])
            # Prune lazily: only once the list exceeds the cap by the threshold
            # factor, then keep just the newest MAX_LOGS_TO_CACHE lines.
            total = len(pid_logs)
            if total > MAX_LOGS_TO_CACHE * LOG_PRUNE_THREASHOLD:
                del pid_logs[:total - MAX_LOGS_TO_CACHE]
            ip_logs[worker_pid] = pid_logs
            DataSource.ip_and_pid_to_logs[node_ip] = ip_logs
        logger.info(f"Received a log for {node_ip} and {worker_pid}")

    aioredis_client = self._dashboard_head.aioredis_client
    receiver = Receiver()
    channel = receiver.channel(gcs_utils.LOG_FILE_CHANNEL)
    await aioredis_client.subscribe(channel)
    logger.info("Subscribed to %s", channel)
    async for _sender, raw in receiver.iter():
        try:
            batch = json.loads(ray._private.utils.decode(raw))
            batch["pid"] = str(batch["pid"])
            process_log_batch(batch)
        except Exception:
            logger.exception("Error receiving log from Redis.")
async def test_stopped(create_connection, server, loop):
    """A publish after Receiver.stop() is logged as one warning; the receiver stays closed."""
    sub = await create_connection(server.tcp_address, loop=loop)
    pub = await create_connection(server.tcp_address, loop=loop)
    mpsc = Receiver(loop=loop)
    await sub.execute_pubsub('subscribe', mpsc.channel('channel:1'))
    assert mpsc.is_active
    mpsc.stop()

    with logs('aioredis', 'DEBUG') as captured:
        await pub.execute('publish', 'channel:1', b'Hello')
        await asyncio.sleep(0, loop=loop)

    # Exactly one record: the warning about the message delivered after stop().
    assert len(captured.output) == 1
    expected_warning = (
        "WARNING:aioredis:Pub/Sub listener message after stop: "
        "sender: <_Sender name:b'channel:1', is_pattern:False, receiver:"
        "<Receiver is_active:False, senders:1, qsize:0>>, data: b'Hello'"
    )
    assert captured.output == [expected_warning]

    # A stopped receiver refuses get() and reports no pending message.
    with pytest.raises(ChannelClosedError):
        await mpsc.get()
    assert (await mpsc.wait_message()) is False
async def redis_relay(websocket):
    """Forward every message published on the ``marketupdates`` channel to *websocket*.

    Runs until the receiver reports no further messages, then closes its
    dedicated Redis connection. The connection is also closed if sending to the
    websocket fails.
    """
    conn = await aioredis.create_connection(("localhost", 6379))
    receiver = Receiver()
    try:
        # Await the SUBSCRIBE reply. Previously the returned future was dropped,
        # so a failed subscription was never reported.
        await conn.execute_pubsub("subscribe", receiver.channel("marketupdates"))
        while await receiver.wait_message():
            *_, message = await receiver.get()
            await websocket.send(message.decode())
    finally:
        conn.close()
        await conn.wait_closed()
async def test_stopped(create_connection, server, loop):
    """Once stopped, a Receiver logs late messages as a warning and refuses reads."""
    sub = await create_connection(server.tcp_address, loop=loop)
    pub = await create_connection(server.tcp_address, loop=loop)
    receiver = Receiver(loop=loop)
    await sub.execute_pubsub('subscribe', receiver.channel('channel:1'))
    assert receiver.is_active
    receiver.stop()

    with logs('aioredis', 'DEBUG') as log_capture:
        await pub.execute('publish', 'channel:1', b'Hello')
        await asyncio.sleep(0, loop=loop)

    # One record only: the warning for the message that arrived after stop().
    assert len(log_capture.output) == 1
    warn_message = (
        "WARNING:aioredis:Pub/Sub listener message after stop: "
        "sender: <_Sender name:b'channel:1', is_pattern:False, receiver:"
        "<Receiver is_active:False, senders:1, qsize:0>>, data: b'Hello'")
    assert log_capture.output == [warn_message]

    # Reading from a stopped receiver fails; wait_message() reports nothing pending.
    with pytest.raises(ChannelClosedError):
        await receiver.get()
    pending = await receiver.wait_message()
    assert pending is False
async def connect(
    self, *, conn_retries: int = 5, conn_retry_delay: int = 1, retry: int = 0
) -> "Broadcast":
    """Open the publish and subscribe Redis connections, retrying on failure.

    Args:
        conn_retries: number of retries allowed after the first failed attempt.
        conn_retry_delay: seconds to sleep between attempts.
        retry: number of attempts already made. Kept for backward compatibility
            with callers that pass it explicitly.

    Returns:
        ``self``, so the call can be chained.

    Raises:
        The last ``ConnectionError`` / ``aioredis.RedisError`` /
        ``asyncio.TimeoutError`` once ``retry`` reaches ``conn_retries``.
    """
    while True:
        pub_conn = None
        try:
            pub_conn = await aioredis.create_redis(self.connection_url)
            sub_conn = await aioredis.create_redis(self.connection_url)
        except (ConnectionError, aioredis.RedisError, asyncio.TimeoutError) as e:
            # If only the second connection failed, close the first instead of
            # leaking it across retries.
            if pub_conn is not None:
                pub_conn.close()
                await pub_conn.wait_closed()
            if retry >= conn_retries:
                logger.error("Connecting to Redis failed")
                raise
            logger.warning(
                "Redis connection error %s %s %s, %d retries remaining...",
                self.connection_url,
                e.__class__.__name__,
                e,
                conn_retries - retry,
            )
            await asyncio.sleep(conn_retry_delay)
            retry += 1
            continue

        self._pub_conn = pub_conn
        self._sub_conn = sub_conn
        self._receiver = Receiver()
        if retry > 0:
            logger.info("Redis connection successful")
        return self
def __init__(self, bot) -> None:
    """Keep a reference to the bot, schedule async setup, and register Redis handlers."""
    self.bot: GearBot = bot
    # Only scheduled; init() runs once the event loop regains control.
    bot.loop.create_task(self.init())
    self.redis_link = None
    self.receiver = Receiver(loop=bot.loop)
    # Incoming message type -> handler coroutine.
    self.handlers = {"guild_perm_request": self.guild_perm_request}
    # Coroutine object only. NOTE(review): presumably scheduled elsewhere (e.g. by init()) — confirm.
    self.task = self._receiver()
async def before_server_start(_sanic, loop):
    """Subscribe to the scoreboard and stolen-flags channels and start the Socket.IO relay."""
    receiver = Receiver(loop=loop)
    redis = await storage.get_async_redis_pool(loop)
    wanted = (receiver.channel('scoreboard'), receiver.channel('stolen_flags'))
    await redis.subscribe(*wanted)
    sio.start_background_task(background_task, receiver)
async def receive(self):
    """Subscribe to ``self.channel`` and print every message until the subscription ends.

    The dedicated Redis connection is closed on exit, including on error.
    """
    print('Receive from redis')
    sub = await aioredis.create_connection(('localhost', 6379))
    receiver = Receiver()
    try:
        # Await the SUBSCRIBE reply. Previously the future was dropped, so
        # failures went unnoticed.
        await sub.execute_pubsub('subscribe', receiver.channel(self.channel))
        while await receiver.wait_message():
            msg = await receiver.get()
            print("Got Message:", msg)
    finally:
        sub.close()
        await sub.wait_closed()
async def run(self, loop, **kwargs):
    """Consume ``TICK:*`` pattern messages and pass each JSON-decoded tick to ``handle_tick``.

    ``loop`` and ``kwargs`` are accepted for interface compatibility but not used.
    """
    aredis = await create_aredis()
    receiver = Receiver()
    tick_pattern = receiver.pattern("TICK:*")
    await aredis.psubscribe(tick_pattern)
    logger.info("subscribed")
    async for _sender, msg in receiver.iter():
        # Pattern messages are (channel, payload); only the payload is decoded.
        payload = msg[1]
        self.handle_tick(ujson.loads(payload))
async def subscribe(self):
    """Return the cached Receiver, creating it and subscribing every ``self._channels`` entry on first use."""
    if self._subscribe:
        return self._subscribe
    redis = await self.redis
    receiver = Receiver(loop=self.loop)
    for name in self._channels:
        await redis.subscribe(receiver.channel(name))
    self._subscribe = receiver
    return self._subscribe
async def _connect_subscr(self) -> None:
    """Open the subscriber connection, subscribe to the configured channel and start the reader."""
    async with self.redis_connect_subscr:
        self.redis_subscr = await create_redis(self.cfg.url, encoding=self.encoding)
        self.mpsc = Receiver(loop=app.loop)
        target_channel = self.mpsc.channel(self.cfg.channel)
        await self.redis_subscr.subscribe(target_channel)
        self._reader_fut = asyncio.ensure_future(self._reader(self.mpsc))
async def subscribe_to_channel(loop, redis_pool):
    """Subscribe a new Receiver to ``websocket_channel`` on a connection taken from *redis_pool*.

    On success the acquired connection is intentionally kept, not released.
    NOTE(review): presumably because it is now in pub/sub mode — confirm.
    If the subscription fails, the connection is returned to the pool before
    the error propagates.

    Returns:
        The subscribed Receiver.
    """
    mpsc = Receiver(loop=loop)
    logger.info("Aquiring redis connection from pool")
    connection = await redis_pool.acquire()
    logger.info("Subscribing to redis channel: %s", websocket_channel)
    try:
        await connection.execute_pubsub("subscribe", mpsc.channel(websocket_channel))
    except BaseException:
        # Including cancellation: don't leak the pool slot.
        redis_pool.release(connection)
        raise
    return mpsc
async def redis_listener():
    """Relay messages from two Redis notification channels to every room in ``users``.

    The Redis connection is closed when the subscription ends or an error escapes.
    """
    redis = await aioredis.create_redis('redis://redishost:6379')
    receiver = Receiver()
    try:
        await redis.subscribe(receiver.channel('sockets:notification:message'),
                              receiver.channel('sockets:notification:message2'))
        while await receiver.wait_message():
            _sender, msg = await receiver.get()
            # Iterate over a snapshot: `users` (presumably module-level state
            # updated by connect/disconnect handlers) may change while we await
            # emit(), and changing a dict/set during iteration raises RuntimeError.
            for sid in list(users):
                # NOTE(review): str() on bytes yields "b'...'"; msg.decode() may be intended.
                await sio.emit('notification:message', {'data': str(msg)}, room=sid)
    finally:
        redis.close()
        await redis.wait_closed()
async def connect(self) -> None:
    """Create the pub/sub connections, message queue, receiver and reader task.

    A second call while any of these already exist only logs a warning.
    """
    already_setup = self._pub_conn or self._sub_conn or self._msg_queue
    if already_setup:
        logger.warning("connections are already setup but connect called again; not doing anything")
        return

    self._pub_conn = await aioredis.create_redis(self.conn_url)
    self._sub_conn = await aioredis.create_redis(self.conn_url)
    # Built inside the coroutine so the queue binds to the running event loop.
    self._msg_queue = asyncio.Queue()
    self._mpsc = Receiver()
    self._reader_task = asyncio.create_task(self._reader())
async def initialize(self):
    """Open the Redis pool, subscribe to ``self.inbound`` and start the fetcher task."""
    self.conn = await aioredis.create_redis_pool(("localhost", 6379), encoding="utf-8", maxsize=2)
    self.receiver = Receiver(loop=self.loop)
    inbound_channel = self.receiver.channel(self.inbound)
    await self.conn.subscribe(inbound_channel)
    self.task = self.loop.create_task(self.fetcher())
    # FIXME - proper logging
    status = f'Redis connection established, listening on {self.inbound}, sending on {self.outbound}'
    print(status)
async def _event_listener(self, channel):
    """Subscribe to *channel* and pass every message to ``_event_assigner``.

    Each iteration of ``AsyncCirculator`` takes a client, subscribes a fresh
    Receiver, and consumes it until its stream ends; the next iteration then
    re-subscribes. NOTE(review): assumes AsyncCirculator keeps yielding
    (retry loop) — confirm.
    """
    async for _ in AsyncCirculator():
        async with self._redis_pool.get_client() as cache:
            receiver = Receiver()
            await cache.subscribe(receiver.channel(channel))
            # Use a distinct loop name. Rebinding `channel` here (as before)
            # made the next re-subscription use a sender object instead of
            # the requested channel name.
            async for sender, message in receiver.iter():
                await self._event_assigner(sender, message)
async def _update_jobs(self):
    """Keep ``DataSource.jobs`` in sync with GCS.

    1. Pattern-subscribes to ``f"{JOB_CHANNEL}:*"``.
    2. Loads a full snapshot with ``GetAllJobInfo``, retrying after any error
       (including a non-zero reply status), and replaces ``DataSource.jobs``.
    3. Applies each subsequent pub/sub job update, but only for jobs whose
       JobID reports ``is_submitted_from_dashboard()``.
    """
    # Subscribe job channel.
    # NOTE(review): subscribing before taking the snapshot is presumably meant to
    # avoid missing updates published while the snapshot is fetched — confirm.
    aioredis_client = self._dashboard_head.aioredis_client
    receiver = Receiver()
    key = f"{job_consts.JOB_CHANNEL}:*"
    pattern = receiver.pattern(key)
    await aioredis_client.psubscribe(pattern)
    logger.info("Subscribed to %s", key)

    # Get all job info.
    while True:
        try:
            logger.info("Getting all job info from GCS.")
            request = gcs_service_pb2.GetAllJobInfoRequest()
            reply = await self._gcs_job_info_stub.GetAllJobInfo(request, timeout=5)
            if reply.status.code == 0:
                jobs = {}
                for job_table_data in reply.job_info_list:
                    data = job_table_data_to_dict(job_table_data)
                    jobs[data["jobId"]] = data
                # Update jobs.
                DataSource.jobs.reset(jobs)
                logger.info("Received %d job info from GCS.", len(jobs))
                break
            else:
                # Routed into the except below so it is logged and retried.
                raise Exception(
                    f"Failed to GetAllJobInfo: {reply.status.message}")
        except Exception:
            logger.exception("Error Getting all job info from GCS.")
            await asyncio.sleep(
                job_consts.RETRY_GET_ALL_JOB_INFO_INTERVAL_SECONDS)

    # Receive jobs from channel.
    async for sender, msg in receiver.iter():
        try:
            # Pattern subscriptions deliver (channel, payload) pairs.
            _, data = msg
            pubsub_message = ray.gcs_utils.PubSubMessage.FromString(data)
            message = ray.gcs_utils.JobTableData.FromString(
                pubsub_message.data)
            job_id = ray._raylet.JobID(message.job_id)
            if job_id.is_submitted_from_dashboard():
                job_table_data = job_table_data_to_dict(message)
                # Rebinds job_id from a JobID object to the dict's "jobId" value,
                # which is the key format used for the snapshot above.
                job_id = job_table_data["jobId"]
                # Update jobs.
                DataSource.jobs[job_id] = job_table_data
            else:
                logger.info(
                    "Ignore job %s which is not submitted from dashboard.",
                    job_id.hex())
        except Exception:
            logger.exception("Error receiving job info.")
async def main():
    """Start a ``channel:1`` reader, then print a heartbeat every 10 s until cancelled.

    The receiver is stopped and Redis is disconnected however the loop exits.
    """
    await redis.connect()
    mpsc = Receiver(loop=asyncio.get_event_loop())
    try:
        # Keep a reference: the event loop holds only weak references to tasks,
        # so an unreferenced task can be garbage-collected mid-run.
        reader_task = asyncio.ensure_future(reader(mpsc))  # noqa: F841
        await redis._redis.subscribe(mpsc.channel('channel:1'))
        while True:
            try:
                await asyncio.sleep(10)
                print('hearbeat', flush=True)
            except asyncio.CancelledError:
                # On Python < 3.8 CancelledError subclasses Exception; never swallow it.
                raise
            except Exception as ex:
                print(ex, flush=True)
    finally:
        # This cleanup used to sit after `while True` and could never run.
        mpsc.stop()
        await redis.disconnect()
async def subscribe_chat(self):
    """Subscribe to ``CHAT_CHANNEL`` on a dedicated connection and yield ``(channel, getter)``.

    ``await getter()`` returns the next chat message payload (without its sender).
    The dedicated connection is always closed when the generator finishes, even
    if the consumer raised. Previously it was never closed, and an exception
    skipped the cleanup entirely.
    """
    redis_for_pubsub = await aioredis.create_redis(self.redis.address, db=self.redis.db)
    try:
        receiver = Receiver()
        channel = receiver.channel(CHAT_CHANNEL)
        await redis_for_pubsub.subscribe(channel)

        async def _get():
            # receiver.get() returns (sender, message); callers only need the message.
            return (await receiver.get())[1]

        yield channel, _get
        await redis_for_pubsub.unsubscribe(CHAT_CHANNEL)
    finally:
        redis_for_pubsub.close()
        await redis_for_pubsub.wait_closed()
async def _update_error_info(self):
    """Collect Ray error reports and index them by node IP and worker PID.

    Uses the GCS error subscriber when the dashboard head provides one and
    otherwise falls back to a Redis pattern subscription. Each error whose text
    contains ``(pid=<pid>, ip=<ip>)`` is appended to
    ``DataSource.ip_and_pid_to_errors[ip][pid]``.
    """

    def process_error(error_data):
        # Strip ANSI colour escapes, then extract the "(pid=..., ip=...)" marker.
        text = re.sub(r"\x1b\[\d+m", "", error_data.error_message)
        found = re.search(r"\(pid=(\d+), ip=(.*?)\)", text)
        if not found:
            return
        pid, ip = found.group(1), found.group(2)
        # Copy before modifying; the final assignment publishes the update.
        ip_errors = dict(DataSource.ip_and_pid_to_errors.get(ip, {}))
        entries = list(ip_errors.get(pid, []))
        entries.append(
            {
                "message": text,
                "timestamp": error_data.timestamp,
                "type": error_data.type,
            }
        )
        ip_errors[pid] = entries
        DataSource.ip_and_pid_to_errors[ip] = ip_errors
        logger.info(f"Received error entry for {ip} {pid}")

    if not self._dashboard_head.gcs_error_subscriber:
        # Redis fallback: pattern messages arrive as (channel, payload).
        aioredis_client = self._dashboard_head.aioredis_client
        receiver = Receiver()
        key = gcs_utils.RAY_ERROR_PUBSUB_PATTERN
        await aioredis_client.psubscribe(receiver.pattern(key))
        logger.info("Subscribed to %s", key)
        async for _, msg in receiver.iter():
            try:
                _, payload = msg
                wrapper = gcs_utils.PubSubMessage.FromString(payload)
                process_error(gcs_utils.ErrorTableData.FromString(wrapper.data))
            except Exception:
                logger.exception("Error receiving error info from Redis.")
        return

    while True:
        try:
            _, error_data = await self._dashboard_head.gcs_error_subscriber.poll()
            if error_data is not None:
                process_error(error_data)
        except Exception:
            logger.exception("Error receiving error info from GCS.")
async def test_decode_message_error(loop):
    """get(decoder=json.loads) is expected to raise TypeError for both a channel and a pattern.

    NOTE(review): json.loads accepts bytes on Python >= 3.6, so the plain-channel
    case presumably raises only on older interpreters or is skipped elsewhere — confirm.
    """
    receiver = Receiver(loop)
    cases = (
        (receiver.channel, 'channel:1', b'{"hello": "world"}',
         (mock.ANY, {'hello': 'world'})),
        (receiver.pattern, '*', (b'channel', b'{"hello": "world"}'),
         (mock.ANY, b'channel', {'hello': 'world'})),
    )
    for make_sender, name, payload, unexpected in cases:
        # Senders are created lazily, in the same order as before.
        make_sender(name).put_nowait(payload)
        with pytest.raises(TypeError):
            assert (await receiver.get(decoder=json.loads)) == unexpected
async def test_decode_message_error():
    """get(decoder=json.loads) is expected to raise TypeError for both a channel and a pattern.

    NOTE(review): json.loads accepts bytes on Python >= 3.6, so the plain-channel
    case presumably raises only on older interpreters or is skipped elsewhere — confirm.
    """
    receiver = Receiver()
    cases = (
        (receiver.channel, "channel:1", b'{"hello": "world"}',
         (mock.ANY, {"hello": "world"})),
        (receiver.pattern, "*", (b"channel", b'{"hello": "world"}'),
         (mock.ANY, b"channel", {"hello": "world"})),
    )
    for make_sender, name, payload, unexpected in cases:
        make_sender(name).put_nowait(payload)
        with pytest.raises(TypeError):
            assert (await receiver.get(decoder=json.loads)) == unexpected
async def run(self):
    """Consume reporter pub/sub messages and store each node's physical stats, keyed by IP.

    Messages that fail to decode are logged with a traceback and skipped.
    """
    p = self._dashboard_head.aioredis_client
    mpsc = Receiver()
    reporter_key = "{}*".format(reporter_consts.REPORTER_PREFIX)
    await p.psubscribe(mpsc.pattern(reporter_key))
    logger.info("Subscribed to %s", reporter_key)
    async for _sender, msg in mpsc.iter():
        try:
            # Pattern subscriptions deliver (channel, payload) pairs.
            _, data = msg
            data = json.loads(ray.utils.decode(data))
            DataSource.node_physical_stats[data["ip"]] = data
        except Exception:
            # Previously logger.exception(ex) used the exception itself as the
            # log message; give it a descriptive message like the other
            # dashboard listeners.
            logger.exception("Error receiving node physical stats.")
async def test_pubsub_receiver_call_stop_with_empty_queue(
        create_redis, server, loop):
    """stop() on a receiver with an empty queue ends iter() promptly without yielding."""
    sub = await create_redis(server.tcp_address, loop=loop)
    receiver = Receiver(loop=loop)

    # FIXME: currently at least one subscriber is needed
    (snd1,) = await sub.subscribe(receiver.channel('chan:1'))

    started = loop.time()
    loop.call_later(.5, receiver.stop)
    async for _ in receiver.iter():  # noqa (flake8 bug with async for)
        assert False, "StopAsyncIteration not raised"
    elapsed = loop.time() - started
    assert elapsed <= 1.5
    assert not receiver.is_active
async def listen_redis(self, websocket, channels):
    """Subscribe ``self.sub`` to *channels* and forward each message to *websocket* as bytes.

    On error: the traceback is logged and ``close_connections()`` is awaited.
    On cancellation: connections are closed and the cancellation propagates
    (the previous bare ``except:`` swallowed it).
    """
    # Index instead of channels.pop(0) so the caller's list is not mutated.
    # An empty list still raises IndexError here, as before.
    first, rest = channels[0], channels[1:]
    try:
        mpsc = Receiver(loop=asyncio.get_event_loop())
        await self.sub.subscribe(
            mpsc.channel(first), *(mpsc.channel(name) for name in rest))
        async for _sender, msg in mpsc.iter():
            # NOTE(review): once the websocket disconnects, messages are still
            # consumed and dropped; consider breaking out instead.
            if websocket.client_state == WebSocketState.CONNECTED:
                await websocket.send_bytes(msg)
    except asyncio.CancelledError:
        await self.close_connections()
        raise
    except Exception:
        logger.exception("Error while relaying Redis messages to websocket")
        await self.close_connections()
    finally:
        logger.info('Connection closed')
async def test_wait_message(create_connection, server, loop):
    """wait_message() stays pending until something is published, then resolves to True."""
    sub = await create_connection(server.tcp_address, loop=loop)
    pub = await create_connection(server.tcp_address, loop=loop)
    mpsc = Receiver(loop=loop)
    await sub.execute_pubsub('subscribe', mpsc.channel('channel:1'))
    fut = asyncio.ensure_future(mpsc.wait_message(), loop=loop)
    assert not fut.done()
    # Still pending after a loop iteration: nothing has been published yet.
    await asyncio.sleep(0, loop=loop)
    assert not fut.done()

    await pub.execute('publish', 'channel:1', 'hello')
    # The exact number of yields matters: one for the connection to read the
    # message, one for the waiter's result to be set.
    await asyncio.sleep(0, loop=loop)  # read in connection
    await asyncio.sleep(0, loop=loop)  # call Future.set_result
    assert fut.done()
    res = await fut
    assert res is True
async def test_decode_message(loop):
    """Receiver.get() applies ``encoding`` and then ``decoder`` to plain-channel payloads."""
    receiver = Receiver(loop)
    sender = receiver.channel('channel:1')
    scenarios = (
        (b'Some data', {'encoding': 'utf-8'}, 'Some data'),
        ('{"hello": "world"}', {'decoder': json.loads}, {'hello': 'world'}),
        (b'{"hello": "world"}', {'encoding': 'utf-8', 'decoder': json.loads},
         {'hello': 'world'}),
    )
    for raw, options, expected in scenarios:
        sender.put_nowait(raw)
        result = await receiver.get(**options)
        assert isinstance(result[0], _Sender)
        assert result[1] == expected
async def test_decode_message_for_pattern(loop):
    """For pattern senders, ``encoding``/``decoder`` apply only to the payload, not the channel name."""
    receiver = Receiver(loop)
    sender = receiver.pattern('*')
    scenarios = (
        ((b'channel', b'Some data'), {'encoding': 'utf-8'},
         (b'channel', 'Some data')),
        ((b'channel', '{"hello": "world"}'), {'decoder': json.loads},
         (b'channel', {'hello': 'world'})),
        ((b'channel', b'{"hello": "world"}'),
         {'encoding': 'utf-8', 'decoder': json.loads},
         (b'channel', {'hello': 'world'})),
    )
    for raw, options, expected in scenarios:
        sender.put_nowait(raw)
        result = await receiver.get(**options)
        assert isinstance(result[0], _Sender)
        assert result[1] == expected
async def test_unsubscribe(create_connection, server, loop):
    """Unsubscribing channels one at a time keeps the receiver active until the last one goes.

    After the final UNSUBSCRIBE the receiver is inactive and a pending ``get()``
    resolves to ``None``.
    """
    sub = await create_connection(server.tcp_address, loop=loop)
    pub = await create_connection(server.tcp_address, loop=loop)
    mpsc = Receiver(loop=loop)
    await sub.execute_pubsub('subscribe', mpsc.channel('channel:1'), mpsc.channel('channel:3'))
    # One subscriber (our connection) per channel.
    res = await pub.execute("publish", "channel:3", "Hello world")
    assert res == 1
    res = await pub.execute("publish", "channel:1", "Hello world")
    assert res == 1
    assert mpsc.is_active

    # Messages come back in publish order, tagged with their channel.
    assert (await mpsc.wait_message()) is True
    ch, msg = await mpsc.get()
    assert ch.name == b'channel:3'
    assert not ch.is_pattern
    assert msg == b"Hello world"
    assert (await mpsc.wait_message()) is True
    ch, msg = await mpsc.get()
    assert ch.name == b'channel:1'
    assert not ch.is_pattern
    assert msg == b"Hello world"

    # Dropping one of two channels leaves the receiver active and still delivering.
    await sub.execute_pubsub('unsubscribe', 'channel:1')
    assert mpsc.is_active
    res = await pub.execute("publish", "channel:3", "message")
    assert res == 1
    assert (await mpsc.wait_message()) is True
    ch, msg = await mpsc.get()
    assert ch.name == b'channel:3'
    assert not ch.is_pattern
    assert msg == b"message"

    # A get() pending when the last channel is unsubscribed resolves to None.
    waiter = asyncio.ensure_future(mpsc.get(), loop=loop)
    await sub.execute_pubsub('unsubscribe', 'channel:3')
    assert not mpsc.is_active
    assert await waiter is None
async def test_subscriptions(create_connection, server, loop):
    """Messages on two subscribed channels are received in publish order with correct metadata."""
    sub = await create_connection(server.tcp_address, loop=loop)
    pub = await create_connection(server.tcp_address, loop=loop)
    receiver = Receiver(loop=loop)
    await sub.execute_pubsub('subscribe', receiver.channel('channel:1'), receiver.channel('channel:3'))

    for name in ("channel:3", "channel:1"):
        delivered_to = await pub.execute("publish", name, "Hello world")
        assert delivered_to == 1
    assert receiver.is_active

    for expected_name in (b'channel:3', b'channel:1'):
        ch, msg = await receiver.get()
        assert ch.name == expected_name
        assert not ch.is_pattern
        assert msg == b"Hello world"
def test_listener_pattern(loop):
    """Receiver.pattern() caches one sender per pattern; closing it deactivates the receiver."""
    receiver = Receiver(loop=loop)
    assert not receiver.is_active

    first = receiver.pattern("*")
    assert isinstance(first, AbcChannel)
    assert receiver.is_active

    # Requesting the same pattern again returns the cached sender.
    again = receiver.pattern('*')
    assert first is again
    assert first.name == again.name
    assert first.is_pattern == again.is_pattern
    assert receiver.is_active

    # Closing the only sender deactivates the receiver; a new request builds a fresh sender.
    first.close()
    assert not first.is_active
    assert not receiver.is_active
    fresh = receiver.pattern("*")
    assert fresh is not first
    assert dict(receiver.channels) == {}
    assert dict(receiver.patterns) == {b'*': fresh}
async def test_pubsub_receiver_stop_on_disconnect(create_redis, server, loop):
    """Killing the subscriber's connection server-side ends ``Receiver.iter()``.

    Messages on two channels plus one matching pattern are read first. After
    ``CLIENT KILL`` the reader task finishes within 1 s and enqueues ``EOF``.
    """
    pub = await create_redis(server.tcp_address, loop=loop)
    sub = await create_redis(server.tcp_address, loop=loop)
    # Give the subscriber a unique name so it can be found in CLIENT LIST and killed.
    sub_name = 'sub-{:X}'.format(id(sub))
    await sub.client_setname(sub_name)
    for sub_info in await pub.client_list():
        if sub_info.name == sub_name:
            break
    # NOTE(review): relies on the loop variable surviving `break`; an empty
    # CLIENT LIST would raise NameError here instead of failing the assert.
    assert sub_info.name == sub_name

    mpsc = Receiver(loop=loop)
    await sub.subscribe(mpsc.channel('channel:1'))
    await sub.subscribe(mpsc.channel('channel:2'))
    await sub.psubscribe(mpsc.pattern('channel:*'))

    q = asyncio.Queue(loop=loop)
    # Sentinel the reader enqueues once iter() has finished.
    EOF = object()

    async def reader():
        async for ch, msg in mpsc.iter(encoding='utf-8'):
            await q.put((ch.name, msg))
        await q.put(EOF)

    tsk = asyncio.ensure_future(reader(), loop=loop)
    await pub.publish_json('channel:1', ['hello'])
    await pub.publish_json('channel:2', ['hello'])
    # receive all messages
    # Each publish arrives once via its channel and once via the 'channel:*' pattern.
    assert await q.get() == (b'channel:1', '["hello"]')
    assert await q.get() == (b'channel:*', (b'channel:1', '["hello"]'))
    assert await q.get() == (b'channel:2', '["hello"]')
    assert await q.get() == (b'channel:*', (b'channel:2', '["hello"]'))

    # XXX: need to implement `client kill`
    assert await pub.execute('client', 'kill', sub_info.addr) in (b'OK', 1)
    await asyncio.wait_for(tsk, timeout=1, loop=loop)
    assert await q.get() is EOF