def create_channel_pool(self, pool_size: int = 2, channel_size: int = 10) -> Pool:
    """
    Build a channel Pool that draws its channels from a private connection Pool.

    Connections are opened with ``connect_robust`` using the ``host``, ``port``,
    ``user`` and ``password`` attributes of ``self.amqp`` (read each time a new
    connection is created).

    :param pool_size: Max size for the underlying connection Pool.
    :type pool_size: integer
    :param channel_size: Max size for the channel Pool.
    :type channel_size: integer
    :return: the channel Pool.
    """

    async def open_connection():
        settings = self.amqp
        return await connect_robust(
            host=settings.host,
            port=settings.port,
            login=settings.user,
            password=settings.password,
        )

    connections = Pool(open_connection, max_size=pool_size)

    async def open_channel() -> aio_pika.Channel:
        # Borrow a connection only long enough to open a channel on it.
        async with connections.acquire() as conn:
            return await conn.channel()

    return Pool(open_channel, max_size=channel_size)
async def connection_queue(conf: dict):
    """
    Build an aio_pika channel Pool backed by a Pool of robust connections.

    Recognised ``conf`` keys (example values)::

        RABBITMQ_HOST: 127.0.0.1
        RABBITMQ_USER: guest
        RABBITMQ_PASSWORD: guest
        RABBITMQ_PORT: 5672
        RABBITMQ_VHOST: /
        RABBITMQ_SIZE: 100    # max size of BOTH pools; 100 when missing/falsy

    Numeric values (``RABBITMQ_PORT``, ``RABBITMQ_SIZE``) may be given as ints
    or numeric strings (e.g. when read from environment variables); they are
    converted with ``int()``.

    :param conf: mapping with the keys above.
    :return: channel_pool
    :raises ValueError: if ``RABBITMQ_PORT`` / ``RABBITMQ_SIZE`` is not numeric.
    """
    loop = asyncio.get_event_loop()
    # A string max_size would make Pool's size comparisons raise TypeError.
    pool_size = int(conf.get('RABBITMQ_SIZE') or 100)
    port = conf.get('RABBITMQ_PORT')
    if port is not None:
        port = int(port)

    async def get_connection() -> "aio_pika.RobustConnection":
        return await aio_pika.connect_robust(
            url=None,
            loop=loop,
            host=conf.get('RABBITMQ_HOST'),
            port=port,
            login=conf.get('RABBITMQ_USER'),
            password=conf.get('RABBITMQ_PASSWORD'),
            virtualhost=conf.get('RABBITMQ_VHOST'),
        )

    connection_pool = Pool(get_connection, max_size=pool_size, loop=loop)

    async def get_channel() -> aio_pika.Channel:
        async with connection_pool.acquire() as connection:
            return await connection.channel()

    channel_pool = Pool(get_channel, max_size=pool_size, loop=loop)
    return channel_pool
def setUp(self):
    """Give every test an empty instance registry and a fresh Pool."""
    super().setUp()
    # Registry of created instances; presumably filled by create_instance — confirm.
    self.counter = set()
    self.pool = Pool(
        self.create_instance,
        max_size=self.max_size,
        loop=self.loop,
    )
class RabbitMQ(BaseExtension):
    """
    Extension that reads JSON events from a RabbitMQ queue and hands them to a
    dispatcher.

    Connections (``connect_robust`` on ``rabbitmq_url``) and channels are kept
    in two Pools bound to ``vk.loop``.
    """

    key = "rabbitmq"

    def __init__(
        self,
        vk,
        queue_name: str,
        rabbitmq_url: str = "amqp://*****:*****@127.0.0.1/",
        max_connections: int = 2,
        max_channels: int = 15,
    ):
        # aio_pika is presumably bound to a falsy value when its import failed — confirm.
        if not aio_pika:
            raise RuntimeWarning(
                "Please install aio_pika (pip install aio_pika) for use this extension"
            )

        self._vk = vk
        self._queue_name = queue_name
        self._url = rabbitmq_url
        self._conn_pool = Pool(
            self.get_connection, max_size=max_connections, loop=vk.loop
        )
        self._chann_pool = Pool(
            self.get_channel, max_size=max_channels, loop=vk.loop
        )

    async def get_events(self) -> None:
        """No-op: events are delivered by ``run`` instead of being polled."""

    async def get_connection(self):
        """Open a robust connection to the configured broker URL."""
        return await aio_pika.connect_robust(self._url)

    async def get_channel(self) -> "aio_pika.Channel":
        """Open a channel on a connection borrowed from the connection pool."""
        async with self._conn_pool.acquire() as conn:
            return await conn.channel()

    async def run(self, dp):
        """
        Consume ``self._queue_name`` and dispatch each message through ``dp``.

        Every body is decoded, parsed with ``JSON_LIBRARY.loads``, passed to
        ``dp._process_events`` as a one-element list and then acked. An
        exception while processing propagates and leaves that message unacked.
        """
        logger.info("RabbitMQ consumer started!")
        async with self._chann_pool.acquire() as channel:  # type: aio_pika.Channel
            await channel.set_qos(10)
            queue = await channel.declare_queue(
                self._queue_name, durable=False, auto_delete=False
            )
            async with queue.iterator() as incoming_messages:
                async for incoming in incoming_messages:
                    raw_body = incoming.body.decode()
                    event = JSON_LIBRARY.loads(raw_body)
                    await dp._process_events([event])
                    await incoming.ack()
async def test_simple(max_size, loop):
    """
    Hit the Pool with 200 concurrent acquirers and check that it creates
    exactly ``max_size`` instances.
    """
    counter = 0

    async def create_instance():
        nonlocal counter
        await asyncio.sleep(0)
        counter += 1
        return counter

    pool = Pool(create_instance, max_size=max_size, loop=loop)

    async def getter():
        # Only reads `counter` / `pool`, so no `nonlocal` is needed here.
        async with pool.acquire() as instance:
            assert instance > 0
            await asyncio.sleep(0.01)
            return counter

    results = await asyncio.gather(*[getter() for _ in range(200)])

    for result in results:
        # Each getter held an instance, so at least one was created, and the
        # pool must never have created more than max_size.
        assert 0 < result <= max_size

    assert counter == max_size
class TestCaseItemReuse(BaseTestCase):
    """Check that the Pool reuses a fixed set of instances evenly."""

    max_size = 10
    call_count = max_size * 5

    def setUp(self):
        super().setUp()
        self.counter = set()
        self.pool = Pool(self.create_instance, max_size=self.max_size, loop=self.loop)

    async def create_instance(self):
        """Create a fresh opaque object and record it in ``self.counter``."""
        obj = object()
        self.counter.add(obj)
        return obj

    async def test_simple(self):
        """
        ``call_count`` concurrent acquisitions must be spread over exactly the
        instances the pool created, each one used the same number of times.
        """
        counter = Counter()

        async def getter():
            # `counter[...] += 1` mutates the Counter; no rebinding, so no nonlocal.
            async with self.pool.acquire() as instance:
                # asyncio.sleep()/gather() lost their `loop` argument in
                # Python 3.10; both use the running loop (self.loop).
                await asyncio.sleep(0.05)
                counter[instance] += 1

        await asyncio.gather(
            *[getter() for _ in range(self.call_count)],
            return_exceptions=True,
        )

        self.assertEqual(sum(counter.values()), self.call_count)
        self.assertEqual(self.counter, set(counter))
        self.assertEqual(len(set(counter.values())), 1)
async def main():
    """
    Demo: publish 10000 messages through pooled channels while one consumer
    prints and acks everything on ``pool_queue``.

    Connections use ``NonRestoringRobustConnection``. ``publish`` reopens a
    pooled channel that has been closed before it is used.
    """
    loop = asyncio.get_event_loop()

    async def get_connection():
        return await aio_pika.connect_robust(
            "amqp://*****:*****@localhost/",
            # Use the connection class that does not restore connections
            connection_class=NonRestoringRobustConnection,
        )

    connection_pool = Pool(get_connection, max_size=2, loop=loop)

    async def get_channel() -> aio_pika.Channel:
        async with connection_pool.acquire() as connection:
            return await connection.channel()

    channel_pool = Pool(get_channel, max_size=10, loop=loop)
    queue_name = "pool_queue"

    async def consume():
        async with channel_pool.acquire() as channel:  # type: aio_pika.Channel
            await channel.set_qos(10)

            queue = await channel.declare_queue(queue_name, durable=False, auto_delete=False)

            async with queue.iterator() as queue_iter:
                async for message in queue_iter:
                    print(message)
                    await message.ack()

    async def publish():
        async with channel_pool.acquire() as channel:  # type: aio_pika.Channel
            # Reopen channels that have been closed previously
            if channel.is_closed:
                await channel.reopen()

            await channel.default_exchange.publish(
                aio_pika.Message(("Channel: %r" % channel).encode()),
                queue_name,
            )

    async with connection_pool, channel_pool:
        task = loop.create_task(consume())
        # asyncio.wait() rejects bare coroutines since Python 3.11; wrapping
        # each one in a Task is what older versions did implicitly.
        await asyncio.wait([loop.create_task(publish()) for _ in range(10000)])
        # consume() iterates the queue indefinitely, so this only returns if
        # the iterator ends or raises.
        await task
def pool(self, max_size, instances, loop):
    """
    Return a Pool whose factory builds ``TestInstanceBase.Instance`` objects
    and records every one of them in ``instances``.
    """

    async def make_instance():
        created = TestInstanceBase.Instance()
        instances.add(created)
        return created

    return Pool(make_instance, max_size=max_size, loop=loop)
def __init__(
    self,
    vk,
    queue_name: str,
    rabbitmq_url: str = "amqp://*****:*****@127.0.0.1/",
    max_connections: int = 2,
    max_channels: int = 15,
):
    """
    Store the consumer settings and create the connection and channel Pools
    (both bound to ``vk.loop``).

    :raises RuntimeWarning: when ``aio_pika`` is not available.
    """
    # aio_pika is presumably bound to a falsy value when its import failed — confirm.
    if not aio_pika:
        raise RuntimeWarning(
            "Please install aio_pika (pip install aio_pika) for use this extension"
        )

    self._vk = vk
    self._queue_name = queue_name
    self._url = rabbitmq_url
    self._conn_pool = Pool(
        self.get_connection, max_size=max_connections, loop=vk.loop
    )
    self._chann_pool = Pool(
        self.get_channel, max_size=max_channels, loop=vk.loop
    )
async def main():
    """
    Demo: publish 10000 messages through pooled channels while one consumer
    prints and acks everything on ``pool_queue``.
    """
    loop = asyncio.get_event_loop()

    async def get_connection():
        return await aio_pika.connect_robust("amqp://*****:*****@localhost/")

    connection_pool = Pool(get_connection, max_size=2, loop=loop)

    async def get_channel() -> aio_pika.Channel:
        async with connection_pool.acquire() as connection:
            return await connection.channel()

    channel_pool = Pool(get_channel, max_size=10, loop=loop)
    queue_name = "pool_queue"

    async def consume():
        async with channel_pool.acquire() as channel:  # type: aio_pika.Channel
            await channel.set_qos(10)

            queue = await channel.declare_queue(
                queue_name, durable=False, auto_delete=False
            )

            # The iterator context closes the consumer when the loop exits.
            async with queue.iterator() as queue_iter:
                async for message in queue_iter:
                    print(message)
                    # In current aio_pika ack() is a coroutine function; calling
                    # it without await never sends the acknowledgement.
                    await message.ack()

    async def publish():
        async with channel_pool.acquire() as channel:  # type: aio_pika.Channel
            await channel.default_exchange.publish(
                aio_pika.Message(
                    ("Channel: %r" % channel).encode()
                ),
                queue_name,
            )

    # Close both pools on exit.
    async with connection_pool, channel_pool:
        task = loop.create_task(consume())
        # asyncio.wait() rejects bare coroutines since Python 3.11.
        await asyncio.wait([loop.create_task(publish()) for _ in range(10000)])
        await task
async def get_channel_pool(self):
    """
    Return the process-wide channel Pool for ``self._url``, building it on first use.

    Pools are cached in ``GLOBAL_CHANNEL_POOL_MAP`` keyed by URL only, so the
    pool sizes (and loop) of the first caller for a URL are the ones used.
    The result is also stored on ``self._channel_pool``.
    """
    url = self._url
    loop = self._loop

    if url in GLOBAL_CHANNEL_POOL_MAP:
        channel_pool = GLOBAL_CHANNEL_POOL_MAP[url]
    else:
        async def open_connection():
            return await aio_pika.connect_robust(url, loop=loop)

        connections = Pool(open_connection, max_size=self._max_pool_size, loop=loop)

        async def open_channel() -> aio_pika.Channel:
            async with connections.acquire() as connection:
                return await connection.channel()

        channel_pool = Pool(open_channel, max_size=self._max_channel_pool_size, loop=loop)
        GLOBAL_CHANNEL_POOL_MAP[url] = channel_pool

    self._channel_pool = channel_pool
    return channel_pool
async def send_msg(pool: Pool, rk: str, msg: Message):
    """
    Publish ``msg`` with routing key ``rk`` to a TOPIC exchange named after the
    first dot-separated segment of ``rk`` (``"orders.created"`` -> ``"orders"``).

    The exchange is declared on every call, using a channel borrowed from ``pool``.
    """
    async with pool.acquire() as channel:
        exchange_name = rk.partition(".")[0]
        exchange = await channel.declare_exchange(exchange_name, ExchangeType.TOPIC)
        await exchange.publish(msg, routing_key=rk)
async def stock_trigger(q_name, init_dict):
    """
    Main logic for handling abnormal stock-movement alerts.

    Consumes quote messages from the RabbitMQ queue ``q_name``. For each message
    whose status passes ``check_stock_status`` it evaluates five groups of alert
    rules against the per-stock state in ``init_dict``. It mutates that state,
    passes the triggered alert codes to ``aio_redis_dao`` and finally acks the
    message.

    Layout of ``init_dict[code]`` as used below:
        [0] stock name
        [1] / [2] 30-day high / low threshold
        [3] / [4] "broke 30-day high / low" flag
        [5] / [6] "daily rise / fall alarm" flag
        [7] / [8] "limit-up / limit-down" flag
        [9] / [10] td of the last 5-minute rise / fall alert
        [11] rolling five-minute state
    (Flags are stored as '1' / '0' strings.)

    :param q_name: queue name (also used as the logger name)
    :param init_dict: per-stock state keyed by stock code (30-day new high / new
        low data plus alert flags); mutated in place
    :return: None. The coroutine only finishes if the queue iterator ends or raises.
    """
    # NOTE(review): `logger` is created but never used in this function.
    logger = Logger.with_default_handlers(name=q_name)
    # Get the current event loop. If there is no current event loop set in the current OS thread and set_event_loop()
    # has not yet been called, asyncio will create a new event loop and set it as the current one.
    loop = asyncio.get_event_loop()
    # async redis connection pool
    # NOTE(review): despite its name, `redis_loop` is the aioredis pool, and it is never closed here.
    redis_loop = await aioredis.create_pool('redis://{}:{}'.format(REDIS_DICT['host'], REDIS_DICT['port']),
                                            db=REDIS_DICT['db'], password=REDIS_DICT['password'],
                                            minsize=1, maxsize=10, loop=loop)

    async def get_connection():
        """Open a robust AMQP connection built from config['rabbitmq']."""
        return await aio_pika.connect_robust(
            "amqp://{}:{}@{}:{}{}".format(config['rabbitmq']['username'],
                                         config['rabbitmq']['password'],
                                         config['rabbitmq']['host'],
                                         config['rabbitmq']['port'],
                                         config['rabbitmq']['virtual_host']))

    # Connection pooling
    # NOTE(review): neither AMQP pool is closed when this coroutine exits.
    connection_pool = Pool(get_connection, max_size=2, loop=loop)

    async def get_channel() -> aio_pika.Channel:
        """Open a channel on a connection borrowed from connection_pool."""
        async with connection_pool.acquire() as connection:
            return await connection.channel()

    channel_pool = Pool(get_channel, max_size=10, loop=loop)
    queue_name = q_name

    async def consume():
        """Consume queue_name, apply the alert rules to each message and ack it."""
        async with channel_pool.acquire() as channel:  # type: aio_pika.Channel
            await channel.set_qos(prefetch_count=10)

            queue = await channel.declare_queue(queue_name, durable=True, auto_delete=False)

            async with queue.iterator() as queue_iter:
                async for message in queue_iter:
                    msg = msg2dict(message.body)
                    start_time = time.time()
                    if check_stock_status(msg['status']):
                        # Alert codes triggered by this message; handed to aio_redis_dao below.
                        task_queue = []
                        msg['name'] = init_dict[msg['code']][0]  # add the stock-name field
                        # The same message is evaluated by 5 rule groups
                        # 1. Check whether the 30-day high/low is broken. This alerts only once per day, so whether it
                        #    was already sent must be known (NOTE(review): checked via init_dict flags [3]/[4] here, not Redis)
                        if not str2bool(init_dict[msg['code']][3]) and \
                                float(msg['close']) > float(init_dict[msg['code']][1]):
                            init_dict[msg['code']][3] = '1'
                            task_queue.append("14601")
                            print("14601", msg)
                        elif not str2bool(init_dict[msg['code']][4]) and \
                                float(msg['close']) < float(init_dict[msg['code']][2]):
                            init_dict[msg['code']][4] = '1'
                            task_queue.append("14602")
                            print("14602", msg)
                        else:
                            pass
                        # 2. Price change within five minutes reaches ±1%; may alert several times per day
                        # First compute the five-minute change (original comment: uses a Fibonacci heap,
                        # min-heap + max-heap; fiveMinuteCal is not visible here — confirm)
                        # Time and price are combined into a tuple (td, close)
                        init_dict[msg['code']][11], fiveRatio, td = fiveMinuteCal(init_dict[msg['code']][11], int(msg['td']), float(msg['close']))
                        # Five-minute rise >= 1%: a repeat alert is sent at most once per five minutes
                        # (compared against the td stored in init_dict[code][9])
                        if fiveRatio >= 1 and fiveMinuteBefore2Int(td, CONSTANT.STOCK_TIMESTAMP_FORMAT) > int(init_dict[msg['code']][9]):
                            init_dict[msg['code']][9] = td
                            task_queue.append("14603:{}".format(fiveRatio))
                        # Five-minute fall <= -1%: a repeat alert is sent at most once per five minutes
                        # (compared against the td stored in init_dict[code][10])
                        if fiveRatio <= -1 and fiveMinuteBefore2Int(td, CONSTANT.STOCK_TIMESTAMP_FORMAT) > int(init_dict[msg['code']][10]):
                            init_dict[msg['code']][10] = td
                            task_queue.append("14604:{}".format(fiveRatio))
                        # 3. Check whether today's change exceeds ±CONSTANT.ALARM_QUOTE_CHANGE (original comment says ±7%);
                        #    alerts only once per day (tracked by flags [5]/[6])
                        if not str2bool(init_dict[msg['code']][5]) and \
                                float(msg['riseFallRate']) > CONSTANT.ALARM_QUOTE_CHANGE:
                            init_dict[msg['code']][5] = '1'
                            task_queue.append("14605")
                            print("14605", msg)
                        elif not str2bool(init_dict[msg['code']][6]) and \
                                float(msg['riseFallRate']) < -CONSTANT.ALARM_QUOTE_CHANGE:
                            init_dict[msg['code']][6] = '1'
                            task_queue.append("14606")
                            print("14606", msg)
                        else:
                            pass
                        # 4. Check for limit-up / limit-down today. This may alert several times per day, so whether it
                        #    was already alerted is tracked (here by flags [7]/[8], which rule 5 resets)
                        if not str2bool(init_dict[msg['code']][7]) and \
                                float(msg['close']) >= float(msg['limitHigh']):
                            init_dict[msg['code']][7] = '1'
                            task_queue.append("14607")
                            print("14607", msg)
                        elif not str2bool(init_dict[msg['code']][8]) and \
                                float(msg['close']) <= float(msg['limitLow']):
                            init_dict[msg['code']][8] = '1'
                            task_queue.append("14608")
                            print("14608", msg)
                        else:
                            pass
                        # 5. Once a limit-up/limit-down flag is set, check whether the price has come off the limit
                        #    ("limit opened") and clear the flag
                        if str2bool(init_dict[msg['code']][7]) and float(msg['close']) < float(msg['limitHigh']):
                            init_dict[msg['code']][7] = '0'
                            task_queue.append("14609")
                            print('14609', msg)
                        elif str2bool(init_dict[msg['code']][8]) and float(msg['close']) > float(msg['limitLow']):
                            init_dict[msg['code']][8] = '0'
                            task_queue.append("14610")
                            print('14610', msg)
                        else:
                            pass
                        # interaction redis
                        await aio_redis_dao(redis_loop, task_queue, msg)
                    else:
                        pass
                    print("Consume Time: {}".format(time.time()-start_time))
                    # Confirm message
                    # NOTE(review): an exception above skips this ack and ends the consumer.
                    await message.ack()

    # Equivalent to awaiting consume() directly.
    task = loop.create_task(consume())
    await task
async def consume(loop, sql_template=None, logger=None, config=None, consumer_pool_size=10):
    """
    Drain an AMQP queue into PostgreSQL with ``consumer_pool_size`` concurrent consumers.

    Each consumer borrows a channel from a shared channel pool and a connection
    from an aiopg pool. It then repeatedly fetches one message, executes
    ``sql_template`` with the UTF-8 decoded body as its only parameter and acks
    the message. If the SQL fails, the body is republished to
    ``config["mq_exchange"]`` with ``config["mq_routing_key"]`` (the dead-letter
    destination) and the original message is acked. A consumer stops when
    ``queue.get`` raises ``QueueEmpty``.

    :param loop: event loop passed to ``aio_pika.connect`` and to the Pools.
    :param sql_template: SQL statement with one placeholder; defaults to
        ``config["sql_template"]``.
    :param logger: optional logger; nothing is logged when ``None``.
    :param config: settings dict; when ``None`` it is built from the ``MQ_*``,
        ``DB_*``, ``CONSUMER_POOL_SIZE`` and ``SQL_TEMPLATE`` environment variables.
    :param consumer_pool_size: number of consumers (also the size of the AMQP
        pools and the DB pool minimum); overridden by a truthy
        ``config["consumer_pool_size"]``.
    :raises ValueError, TypeError: if ``config["consumer_pool_size"]`` is not an integer.
    """
    if config is None:
        # NOTE(review): strtobool comes from distutils in many codebases; distutils
        # was removed in Python 3.12 — confirm where it is imported from.
        config = {
            "mq_host": os.environ.get('MQ_HOST'),
            "mq_port": int(os.environ.get('MQ_PORT', '5672')),
            "mq_vhost": os.environ.get('MQ_VHOST'),
            "mq_user": os.environ.get('MQ_USER'),
            "mq_pass": os.environ.get('MQ_PASS'),
            "mq_queue": os.environ.get('MQ_QUEUE'),
            "mq_queue_durable": bool(strtobool(os.environ.get('MQ_QUEUE_DURABLE', 'True'))),
            "mq_exchange": os.environ.get("MQ_EXCHANGE"),
            "mq_routing_key": os.environ.get("MQ_ROUTING_KEY"),
            "db_host": os.environ.get('DB_HOST'),
            "db_port": int(os.environ.get('DB_PORT', '5432')),
            "db_user": os.environ.get('DB_USER'),
            "db_pass": os.environ.get('DB_PASS'),
            "db_database": os.environ.get('DB_DATABASE'),
            "consumer_pool_size": os.environ.get("CONSUMER_POOL_SIZE"),
            "sql_template": os.environ.get('SQL_TEMPLATE')
        }
    if sql_template is None:
        sql_template = config.get("sql_template")

    raw_pool_size = config.get("consumer_pool_size")
    if raw_pool_size:
        try:
            consumer_pool_size = int(raw_pool_size)
        except (TypeError, ValueError):
            # int("abc") raises ValueError; log the offending value, not the default.
            if logger:
                logger.error(f"Invalid pool size: {raw_pool_size}")
            raise

    db_pool = await aiopg.create_pool(host=config.get("db_host"),
                                      user=config.get("db_user"),
                                      password=config.get("db_pass"),
                                      database=config.get("db_database"),
                                      port=config.get("db_port"),
                                      minsize=consumer_pool_size,
                                      maxsize=consumer_pool_size * 2)

    async def get_connection():
        return await aio_pika.connect(host=config.get("mq_host"),
                                      port=config.get("mq_port"),
                                      login=config.get("mq_user"),
                                      password=config.get("mq_pass"),
                                      virtualhost=config.get("mq_vhost"),
                                      loop=loop)

    connection_pool = Pool(get_connection, max_size=consumer_pool_size, loop=loop)

    async def get_channel():
        async with connection_pool.acquire() as connection:
            return await connection.channel()

    channel_pool = Pool(get_channel, max_size=consumer_pool_size, loop=loop)

    async def _push_to_dead_letter_queue(message, channel):
        """Republish the decoded body to the configured dead-letter exchange."""
        exchange = await channel.get_exchange(config.get("mq_exchange"))
        await exchange.publish(message=aio_pika.Message(message.encode("utf-8")),
                               routing_key=config.get("mq_routing_key"))

    async def _consume():
        """Single consumer: fetch -> insert -> ack until the queue is empty."""
        async with channel_pool.acquire() as channel:
            queue = await channel.declare_queue(
                config.get("mq_queue"),
                durable=config.get("mq_queue_durable"),
                auto_delete=False)
            # Context managers release the cursor and return the DB connection
            # to the pool on every exit path, including exceptions.
            async with db_pool.acquire() as db_conn:
                async with db_conn.cursor() as cursor:
                    while True:
                        try:
                            m = await queue.get(timeout=5 * consumer_pool_size)
                        except aio_pika.exceptions.QueueEmpty:
                            if logger:
                                logger.info("Queue empty. Stopping.")
                            break

                        message = m.body.decode('utf-8')
                        if logger:
                            logger.debug(f"Message {message} inserting to db")
                        try:
                            await cursor.execute(sql_template, (message, ))
                        except Exception as e:
                            if logger:
                                logger.error(
                                    f"DB Error: {e}, pushing message to dead letter queue!"
                                )
                            # Must be awaited, otherwise the publish never runs.
                            await _push_to_dead_letter_queue(message, channel)
                        # Ack in both cases: the body is either stored or copied to
                        # the dead-letter queue, so it must not be redelivered here.
                        await m.ack()

    try:
        async with connection_pool, channel_pool:
            if logger:
                logger.info("Consumers started")
            await asyncio.gather(*(_consume() for _ in range(consumer_pool_size)))
    finally:
        db_pool.close()
        await db_pool.wait_closed()
async def connect_to_rabbit():
    """Create the RabbitMQ connection Pool (max 2) and channel Pool (max 10) on ``rmq``."""
    connections = Pool(get_connection_rmq, max_size=2)
    rmq.connections_pool = connections

    channels = Pool(get_channel, max_size=10)
    rmq.channels_pool = channels
async def start():
    """Create the connection Pool (max 5) and channel Pool (max 10) and register them in ``conns``."""
    conns["connection_pool"] = Pool(get_connection, max_size=5)
    conns["channel_pool"] = Pool(get_channel, max_size=10)