示例#1
0
    def create_channel_pool(self,
                            pool_size: int = 2,
                            channel_size: int = 10) -> Pool:
        """
        Build a pool of AMQP channels on top of a pool of robust connections.

        Connections are opened lazily with ``connect_robust`` using the host,
        port, user and password found on ``self.amqp`` at connection time;
        each channel is opened on a connection borrowed from the connection pool.

        :param pool_size: Max size for the underlying connection Pool.
        :type pool_size: int
        :param channel_size: Max size for the channel Pool.
        :type channel_size: int
        :return: The channel Pool.
        """
        async def open_connection():
            settings = self.amqp
            return await connect_robust(
                host=settings.host,
                port=settings.port,
                login=settings.user,
                password=settings.password,
            )

        connections = Pool(open_connection, max_size=pool_size)

        async def open_channel() -> aio_pika.Channel:
            # The connection is only held while the channel is being opened.
            async with connections.acquire() as conn:
                channel = await conn.channel()
            return channel

        channel_pool = Pool(open_channel, max_size=channel_size)
        return channel_pool
示例#2
0
    async def connection_queue(conf: dict):
        """
        Build an aio_pika channel pool from the RabbitMQ settings in ``conf``.

        Keys read from ``conf`` (example values)::

            RABBITMQ_HOST: 127.0.0.1
            RABBITMQ_USER: guest
            RABBITMQ_PASSWORD: guest
            RABBITMQ_PORT: 5672
            RABBITMQ_VHOST: /
            RABBITMQ_SIZE: 100   (optional; max size of both pools, default 100)

        Connections are opened lazily with ``aio_pika.connect_robust``; each
        channel is opened on a connection borrowed from the connection pool.

        :param conf: mapping holding the RabbitMQ settings listed above.
        :return: channel_pool
        """
        # This is a coroutine, so the running loop is the same loop that
        # asyncio.get_event_loop() returned (previously called three times).
        loop = asyncio.get_running_loop()
        # One bound shared by the connection pool and the channel pool.
        pool_size = conf.get('RABBITMQ_SIZE') or 100

        # The previous annotation (`aio_pika.connect_robust`) named the factory
        # function rather than the type of connection it returns.
        async def get_connection() -> "aio_pika.RobustConnection":
            return await aio_pika.connect_robust(
                url=None,
                loop=loop,
                host=conf.get('RABBITMQ_HOST'),
                port=conf.get('RABBITMQ_PORT'),
                login=conf.get('RABBITMQ_USER'),
                password=conf.get('RABBITMQ_PASSWORD'),
                virtualhost=conf.get('RABBITMQ_VHOST'))

        connection_pool = Pool(get_connection, max_size=pool_size, loop=loop)

        async def get_channel() -> aio_pika.Channel:
            async with connection_pool.acquire() as connection:
                return await connection.channel()

        channel_pool = Pool(get_channel, max_size=pool_size, loop=loop)

        return channel_pool
示例#3
0
    def setUp(self):
        """Reset the instance registry and build a fresh Pool for each test."""
        super().setUp()
        self.counter = set()

        factory = self.create_instance
        self.pool = Pool(
            factory,
            max_size=self.max_size,
            loop=self.loop,
        )
示例#4
0
class RabbitMQ(BaseExtension):
    """Extension that consumes events from a RabbitMQ queue via aio_pika.

    Connections and channels are pooled. Each message body is decoded,
    parsed with ``JSON_LIBRARY.loads`` and handed to the dispatcher's
    ``_process_events`` before the message is acknowledged.
    """

    key = "rabbitmq"

    def __init__(
        self,
        vk,
        queue_name: str,
        rabbitmq_url: str = "amqp://*****:*****@127.0.0.1/",
        max_connections: int = 2,
        max_channels: int = 15,
    ):
        # aio_pika is falsy when it is not available; refuse to build then.
        if not aio_pika:
            raise RuntimeWarning(
                "Please install aio_pika (pip install aio_pika) for use this extension"
            )

        self._vk = vk
        self._queue_name = queue_name
        self._url = rabbitmq_url

        pool_loop = vk.loop
        self._conn_pool = Pool(
            self.get_connection, max_size=max_connections, loop=pool_loop
        )
        self._chann_pool = Pool(
            self.get_channel, max_size=max_channels, loop=pool_loop
        )

    async def get_events(self) -> None:
        """Not used by this extension; does nothing."""

    async def get_connection(self):
        """Open a robust connection to ``self._url`` (connection-pool factory)."""
        connection = await aio_pika.connect_robust(self._url)
        return connection

    async def get_channel(self) -> "aio_pika.Channel":
        """Open a channel on a pooled connection (channel-pool factory)."""
        async with self._conn_pool.acquire() as connection:
            channel = await connection.channel()
        return channel

    async def run(self, dp):
        """Consume the configured queue and dispatch every message to ``dp``.

        A pooled channel is used with QoS 10. The queue is declared
        non-durable and not auto-deleted. Each message is decoded, parsed,
        passed to ``dp._process_events`` and then acknowledged.
        """
        logger.info("RabbitMQ consumer started!")
        async with self._chann_pool.acquire() as channel:  # type: aio_pika.Channel
            await channel.set_qos(10)
            queue = await channel.declare_queue(
                self._queue_name, durable=False, auto_delete=False
            )

            async with queue.iterator() as incoming_messages:
                async for incoming in incoming_messages:
                    raw_text = incoming.body.decode()
                    event = JSON_LIBRARY.loads(raw_text)
                    await dp._process_events([event])
                    await incoming.ack()
示例#5
0
async def test_simple(max_size, loop):
    """200 concurrent borrowers must cause exactly ``max_size`` instances."""
    created = 0

    async def make_instance():
        nonlocal created
        await asyncio.sleep(0)
        created += 1
        return created

    pool = Pool(make_instance, max_size=max_size, loop=loop)

    async def borrow():
        # Only reads the enclosing names, so no ``nonlocal`` is required.
        async with pool.acquire() as instance:
            assert instance > 0
            await asyncio.sleep(0.01)
            return created

    results = await asyncio.gather(*(borrow() for _ in range(200)))

    assert all(result > -1 for result in results)
    assert created == max_size
示例#6
0
class TestCaseItemReuse(BaseTestCase):
    """Pool items must be reused evenly across many concurrent acquisitions."""

    max_size = 10
    call_count = max_size * 5

    def setUp(self):
        super().setUp()
        # Every object produced by create_instance is recorded here.
        self.counter = set()

        self.pool = Pool(self.create_instance,
                         max_size=self.max_size,
                         loop=self.loop)

    async def create_instance(self):
        """Create a new unique object and remember it in ``self.counter``."""
        obj = object()
        self.counter.add(obj)
        return obj

    async def test_simple(self):
        counter = Counter()

        async def getter():
            async with self.pool.acquire() as instance:
                # asyncio.sleep()/gather() lost their ``loop`` parameter in
                # Python 3.10 (TypeError). The running loop is used instead;
                # presumably self.loop, as set up by BaseTestCase.
                await asyncio.sleep(0.05)
                counter[instance] += 1

        results = await asyncio.gather(
            *[getter() for _ in range(self.call_count)],
            return_exceptions=True,
        )
        # Surface failures from getter() instead of silently collecting them.
        errors = [r for r in results if isinstance(r, BaseException)]
        self.assertEqual(errors, [])

        self.assertEqual(sum(counter.values()), self.call_count)
        # Every created instance was handed out...
        self.assertEqual(self.counter, set(counter))
        # ...and each one the same number of times.
        self.assertEqual(len(set(counter.values())), 1)
示例#7
0
async def main():
    """Demo: publish 10000 messages to ``pool_queue`` over pooled channels while consuming it.

    Connections use ``NonRestoringRobustConnection``; publishers reopen a
    pooled channel that has been closed.
    """
    loop = asyncio.get_event_loop()

    async def get_connection():
        return await aio_pika.connect_robust(
            "amqp://*****:*****@localhost/",
            # Use the connection class that does not restore connections
            connection_class=NonRestoringRobustConnection,
        )

    connection_pool = Pool(get_connection, max_size=2, loop=loop)

    async def get_channel() -> aio_pika.Channel:
        async with connection_pool.acquire() as connection:
            return await connection.channel()

    channel_pool = Pool(get_channel, max_size=10, loop=loop)
    queue_name = "pool_queue"

    async def consume():
        async with channel_pool.acquire() as channel:  # type: aio_pika.Channel
            await channel.set_qos(10)

            queue = await channel.declare_queue(queue_name,
                                                durable=False,
                                                auto_delete=False)

            async with queue.iterator() as queue_iter:
                async for message in queue_iter:
                    print(message)
                    await message.ack()

    async def publish():
        async with channel_pool.acquire() as channel:  # type: aio_pika.Channel
            # Reopen channels that have been closed previously
            if channel.is_closed:
                await channel.reopen()
            await channel.default_exchange.publish(
                aio_pika.Message(("Channel: %r" % channel).encode()),
                queue_name,
            )

    async with connection_pool, channel_pool:
        task = loop.create_task(consume())
        # asyncio.wait() rejects bare coroutines since Python 3.11 (deprecated
        # in 3.8), so wrap each publish() in a task first.
        publishers = [loop.create_task(publish()) for _ in range(10000)]
        await asyncio.wait(publishers)
        await task
示例#8
0
    def pool(self, max_size, instances, loop):
        """Return a Pool of ``TestInstanceBase.Instance`` objects.

        Every instance the pool creates is also added to ``instances``.
        """
        async def make_instance():
            # ``instances`` is mutated, never rebound, so no nonlocal is needed.
            item = TestInstanceBase.Instance()
            instances.add(item)
            return item

        return Pool(make_instance, max_size=max_size, loop=loop)
示例#9
0
 def __init__(
     self,
     vk,
     queue_name: str,
     rabbitmq_url: str = "amqp://*****:*****@127.0.0.1/",
     max_connections: int = 2,
     max_channels: int = 15,
 ):
     """Store the settings and create the connection and channel pools.

     :param vk: object whose ``loop`` attribute is given to both pools.
     :param queue_name: name of the queue, stored as ``self._queue_name``.
     :param rabbitmq_url: AMQP URL, stored as ``self._url``.
     :param max_connections: max size of the connection pool.
     :param max_channels: max size of the channel pool.
     :raises RuntimeWarning: when ``aio_pika`` is falsy (not installed).
     """
     if not aio_pika:
         raise RuntimeWarning(
             "Please install aio_pika (pip install aio_pika) for use this extension"
         )

     self._vk = vk
     self._queue_name = queue_name
     self._url = rabbitmq_url

     pool_loop = vk.loop
     self._conn_pool = Pool(
         self.get_connection, max_size=max_connections, loop=pool_loop
     )
     self._chann_pool = Pool(
         self.get_channel, max_size=max_channels, loop=pool_loop
     )
示例#10
0
async def main():
    """Demo: publish 10000 messages to ``pool_queue`` over pooled channels while consuming it."""
    loop = asyncio.get_event_loop()

    async def get_connection():
        return await aio_pika.connect_robust("amqp://*****:*****@localhost/")

    connection_pool = Pool(get_connection, max_size=2, loop=loop)

    async def get_channel() -> aio_pika.Channel:
        async with connection_pool.acquire() as connection:
            return await connection.channel()

    channel_pool = Pool(get_channel, max_size=10, loop=loop)
    queue_name = "pool_queue"

    async def consume():
        async with channel_pool.acquire() as channel:  # type: aio_pika.Channel
            await channel.set_qos(10)

            queue = await channel.declare_queue(
                queue_name, durable=False, auto_delete=False
            )

            async with queue.iterator() as queue_iter:
                async for message in queue_iter:
                    print(message)
                    # In aio_pika >= 7 ack() is a coroutine: without await the
                    # acknowledgement is never sent. Older versions return an
                    # awaitable task, so awaiting is correct in both cases.
                    await message.ack()

    async def publish():
        async with channel_pool.acquire() as channel:  # type: aio_pika.Channel
            await channel.default_exchange.publish(
                aio_pika.Message(
                    ("Channel: %r" % channel).encode()
                ),
                queue_name,
            )

    # Close both pools on exit.
    async with connection_pool, channel_pool:
        task = loop.create_task(consume())
        # asyncio.wait() rejects bare coroutines since Python 3.11.
        publishers = [loop.create_task(publish()) for _ in range(10000)]
        await asyncio.wait(publishers)
        await task
示例#11
0
    async def get_channel_pool(self):
        """Return the shared channel pool for ``self._url``, creating it on first use.

        Pools are cached in ``GLOBAL_CHANNEL_POOL_MAP`` keyed by URL only, so
        the loop and pool sizes of the first caller for a URL are the ones used.
        The selected pool is also stored on ``self._channel_pool``.
        """
        url = self._url
        loop = self._loop

        if url in GLOBAL_CHANNEL_POOL_MAP:
            channel_pool = GLOBAL_CHANNEL_POOL_MAP[url]
        else:
            async def open_connection():
                return await aio_pika.connect_robust(url, loop=loop)

            connection_pool = Pool(open_connection,
                                   max_size=self._max_pool_size,
                                   loop=loop)

            async def open_channel() -> aio_pika.Channel:
                async with connection_pool.acquire() as connection:
                    return await connection.channel()

            channel_pool = Pool(open_channel,
                                max_size=self._max_channel_pool_size,
                                loop=loop)
            GLOBAL_CHANNEL_POOL_MAP[url] = channel_pool

        self._channel_pool = channel_pool
        return channel_pool
async def send_msg(pool: Pool, rk: str, msg: Message):
    """Publish ``msg`` with routing key ``rk`` to a TOPIC exchange.

    The exchange is named after the first dot-separated segment of ``rk``
    (e.g. ``"orders"`` for ``"orders.created"``) and is declared before
    each publish.
    """
    async with pool.acquire() as channel:
        exchange_name = rk.partition(".")[0]
        exchange = await channel.declare_exchange(exchange_name, ExchangeType.TOPIC)
        await exchange.publish(msg, routing_key=rk)
async def stock_trigger(q_name, init_dict):
    """
    Main loop that detects unusual stock price movements.

    Consumes quote messages from the durable RabbitMQ queue ``q_name``. For each
    message whose status passes ``check_stock_status``, it evaluates the alert
    rules below and passes the collected alert codes (14601-14610) with the
    message to ``aio_redis_dao``. Every message is acknowledged after processing.

    Per-code state ``init_dict[code]`` as used below (inferred from usage --
    TODO confirm against the code that builds init_dict):
      [0] stock name, [1] 30-day high, [2] 30-day low,
      [3]/[4] "new 30-day high/low already alerted" flags ('0'/'1'),
      [5]/[6] "daily rise/fall threshold already alerted" flags,
      [7]/[8] "currently at limit-up / limit-down" flags,
      [9]/[10] td of the last five-minute rise/fall alert,
      [11] five-minute window state maintained by fiveMinuteCal.

    :param q_name: name of the queue to consume (also used as the logger name).
    :param init_dict: per-stock-code state, including the 30-day highs and lows.
    :return: None
    """
    # NOTE(review): this logger is not used anywhere below.
    logger = Logger.with_default_handlers(name=q_name)
    # Inside a coroutine this returns the running event loop.
    loop = asyncio.get_event_loop()
    # async redis connection pool
    # NOTE(review): neither this pool nor the RabbitMQ pools below are ever closed.
    redis_loop = await aioredis.create_pool('redis://{}:{}'.format(REDIS_DICT['host'], REDIS_DICT['port']),
                                            db=REDIS_DICT['db'],
                                            password=REDIS_DICT['password'],
                                            minsize=1, maxsize=10, loop=loop)

    async def get_connection():
        # Connection-pool factory: robust connection built from config['rabbitmq'].
        return await aio_pika.connect_robust(
            "amqp://{}:{}@{}:{}{}".format(config['rabbitmq']['username'],
                                          config['rabbitmq']['password'],
                                          config['rabbitmq']['host'],
                                          config['rabbitmq']['port'],
                                          config['rabbitmq']['virtual_host']))

    # Connection pooling
    connection_pool = Pool(get_connection, max_size=2, loop=loop)

    async def get_channel() -> aio_pika.Channel:
        # Channel-pool factory: open a channel on a pooled connection.
        async with connection_pool.acquire() as connection:
            return await connection.channel()

    channel_pool = Pool(get_channel, max_size=10, loop=loop)
    queue_name = q_name

    async def consume():
        # NOTE(review): any exception in the loop body (e.g. a KeyError for a
        # code missing from init_dict) ends this consumer; the message is not acked.
        async with channel_pool.acquire() as channel:   # type: aio_pika.Channel
            await channel.set_qos(prefetch_count=10)

            queue = await channel.declare_queue(queue_name, durable=True, auto_delete=False)

            async with queue.iterator() as queue_iter:
                async for message in queue_iter:
                    msg = msg2dict(message.body)
                    start_time = time.time()
                    if check_stock_status(msg['status']):
                        task_queue = []
                        msg['name'] = init_dict[msg['code']][0]  # add the stock-name field
                        # The same message is evaluated by the 5 rule groups below.
                        # 1. New 30-day high / low. Alerts once per day; the original note
                        #    says Redis is checked for "already sent", but the flags
                        #    actually used here are init_dict[code][3] / [4].
                        if not str2bool(init_dict[msg['code']][3]) and \
                                float(msg['close']) > float(init_dict[msg['code']][1]):
                            init_dict[msg['code']][3] = '1'
                            task_queue.append("14601")
                            print("14601", msg)
                        elif not str2bool(init_dict[msg['code']][4]) and \
                                float(msg['close']) < float(init_dict[msg['code']][2]):
                            init_dict[msg['code']][4] = '1'
                            task_queue.append("14602")
                            print("14602", msg)
                        else:
                            pass

                        # 2. Price change within five minutes reaching +/-1%; may alert
                        #    several times per day.
                        # First compute the five-minute change (per the original note,
                        # fiveMinuteCal uses a min-heap + max-heap -- not visible here).
                        # Time and price are combined into a tuple (td, close).
                        init_dict[msg['code']][11], fiveRatio, td = fiveMinuteCal(init_dict[msg['code']][11],
                                                                                        int(msg['td']),
                                                                                        float(msg['close']))
                        # Five-minute rise >= 1%: alert at most once per five minutes.
                        # The last alert's td is kept in init_dict[code][9];
                        # fiveMinuteBefore2Int presumably yields "td minus five minutes" -- confirm.
                        if fiveRatio >= 1 and fiveMinuteBefore2Int(td, CONSTANT.STOCK_TIMESTAMP_FORMAT) > int(init_dict[msg['code']][9]):
                            init_dict[msg['code']][9] = td
                            task_queue.append("14603:{}".format(fiveRatio))

                        # Five-minute drop <= -1%: same once-per-five-minutes rule,
                        # tracked in init_dict[code][10].
                        if fiveRatio <= -1 and fiveMinuteBefore2Int(td, CONSTANT.STOCK_TIMESTAMP_FORMAT) > int(init_dict[msg['code']][10]):
                            init_dict[msg['code']][10] = td
                            task_queue.append("14604:{}".format(fiveRatio))

                        # 3. Daily change beyond +/-CONSTANT.ALARM_QUOTE_CHANGE (the original
                        #    note says 7%). Alerts once per day via flags [5] / [6].
                        if not str2bool(init_dict[msg['code']][5]) and \
                                float(msg['riseFallRate']) > CONSTANT.ALARM_QUOTE_CHANGE:
                            init_dict[msg['code']][5] = '1'
                            task_queue.append("14605")
                            print("14605", msg)
                        elif not str2bool(init_dict[msg['code']][6]) and \
                                float(msg['riseFallRate']) < -CONSTANT.ALARM_QUOTE_CHANGE:
                            init_dict[msg['code']][6] = '1'
                            task_queue.append("14606")
                            print("14606", msg)
                        else:
                            pass

                        # 4. Limit-up / limit-down reached. Can alert several times per
                        #    day because rule 5 clears flags [7] / [8] again.
                        if not str2bool(init_dict[msg['code']][7]) and \
                                float(msg['close']) >= float(msg['limitHigh']):
                            init_dict[msg['code']][7] = '1'
                            task_queue.append("14607")
                            print("14607", msg)
                        elif not str2bool(init_dict[msg['code']][8]) and \
                                float(msg['close']) <= float(msg['limitLow']):
                            init_dict[msg['code']][8] = '1'
                            task_queue.append("14608")
                            print("14608", msg)
                        else:
                            pass

                        # 5. After a limit-up / limit-down flag is set, detect the price
                        #    leaving the limit again (limit "opening") and clear the flag.
                        if str2bool(init_dict[msg['code']][7]) and float(msg['close']) < float(msg['limitHigh']):
                            init_dict[msg['code']][7] = '0'
                            task_queue.append("14609")
                            print('14609', msg)
                        elif str2bool(init_dict[msg['code']][8]) and float(msg['close']) > float(msg['limitLow']):
                            init_dict[msg['code']][8] = '0'
                            task_queue.append("14610")
                            print('14610', msg)
                        else:
                            pass

                        # Hand the collected alert codes and the message to Redis.
                        await aio_redis_dao(redis_loop, task_queue, msg)
                    else:
                        pass
                    print("Consume Time: {}".format(time.time()-start_time))
                    # Acknowledge the message
                    await message.ack()

    # Equivalent to awaiting consume() directly.
    task = loop.create_task(consume())
    await task
示例#14
0
async def consume(loop,
                  sql_template=None,
                  logger=None,
                  config=None,
                  consumer_pool_size=10):
    """Drain a RabbitMQ queue into PostgreSQL using ``consumer_pool_size`` workers.

    Each worker borrows a channel, declares the configured queue and inserts
    every message body (decoded as UTF-8) with ``sql_template``, passing the
    body as the only parameter. If an insert fails, the message is republished
    to ``mq_exchange`` / ``mq_routing_key`` (the dead-letter route) and then
    acknowledged. A worker stops once ``queue.get`` raises ``QueueEmpty`` after
    waiting ``5 * consumer_pool_size`` seconds.

    :param loop: event loop passed to ``aio_pika.connect`` and to the pools.
    :param sql_template: SQL statement with one placeholder; defaults to
        ``config["sql_template"]``.
    :param logger: optional logger; nothing is logged when ``None``.
    :param config: settings mapping; built from ``MQ_*`` / ``DB_*``
        environment variables when ``None``.
    :param consumer_pool_size: number of workers; also sizes the AMQP and DB
        pools. Overridden by a truthy ``config["consumer_pool_size"]``.
    :raises ValueError, TypeError: if ``config["consumer_pool_size"]`` is not
        convertible to ``int``.
    """
    if config is None:
        config = {
            "mq_host": os.environ.get('MQ_HOST'),
            "mq_port": int(os.environ.get('MQ_PORT', '5672')),
            "mq_vhost": os.environ.get('MQ_VHOST'),
            "mq_user": os.environ.get('MQ_USER'),
            "mq_pass": os.environ.get('MQ_PASS'),
            "mq_queue": os.environ.get('MQ_QUEUE'),
            "mq_queue_durable":
            bool(strtobool(os.environ.get('MQ_QUEUE_DURABLE', 'True'))),
            "mq_exchange": os.environ.get("MQ_EXCHANGE"),
            "mq_routing_key": os.environ.get("MQ_ROUTING_KEY"),
            "db_host": os.environ.get('DB_HOST'),
            "db_port": int(os.environ.get('DB_PORT', '5432')),
            "db_user": os.environ.get('DB_USER'),
            "db_pass": os.environ.get('DB_PASS'),
            "db_database": os.environ.get('DB_DATABASE'),
            "consumer_pool_size": os.environ.get("CONSUMER_POOL_SIZE"),
            "sql_template": os.environ.get('SQL_TEMPLATE'),
        }

    if sql_template is None:
        sql_template = config.get("sql_template")

    configured_size = config.get("consumer_pool_size")
    if configured_size:
        try:
            consumer_pool_size = int(configured_size)
        except (TypeError, ValueError):
            # int() raises ValueError for non-numeric strings (e.g. from the
            # environment), not only TypeError. Log the offending value, not
            # the default.
            if logger:
                logger.error(f"Invalid pool size: {configured_size}")
            raise

    db_pool = await aiopg.create_pool(host=config.get("db_host"),
                                      user=config.get("db_user"),
                                      password=config.get("db_pass"),
                                      database=config.get("db_database"),
                                      port=config.get("db_port"),
                                      minsize=consumer_pool_size,
                                      maxsize=consumer_pool_size * 2)

    async def get_connection():
        return await aio_pika.connect(host=config.get("mq_host"),
                                      port=config.get("mq_port"),
                                      login=config.get("mq_user"),
                                      password=config.get("mq_pass"),
                                      virtualhost=config.get("mq_vhost"),
                                      loop=loop)

    connection_pool = Pool(get_connection,
                           max_size=consumer_pool_size,
                           loop=loop)

    async def get_channel():
        async with connection_pool.acquire() as connection:
            return await connection.channel()

    channel_pool = Pool(get_channel, max_size=consumer_pool_size, loop=loop)

    async def _push_to_dead_letter_queue(message, channel):
        """Republish the decoded ``message`` text on the dead-letter route."""
        exchange = await channel.get_exchange(config.get("mq_exchange"))
        await exchange.publish(
            message=aio_pika.Message(message.encode("utf-8")),
            routing_key=config.get("mq_routing_key"),
        )

    async def _consume():
        """Worker: move messages from the queue into the DB until it is empty."""
        async with channel_pool.acquire() as channel:
            queue = await channel.declare_queue(
                config.get("mq_queue"),
                durable=config.get("mq_queue_durable"),
                auto_delete=False)

            # Release the DB cursor/connection back to the pool however the
            # loop ends (previously they leaked on any error).
            async with db_pool.acquire() as db_conn:
                async with db_conn.cursor() as cursor:
                    while True:
                        try:
                            m = await queue.get(timeout=5 * consumer_pool_size)
                        except aio_pika.exceptions.QueueEmpty:
                            if logger:
                                logger.info("Queue empty. Stopping.")
                            break

                        message = m.body.decode('utf-8')
                        if logger:
                            logger.debug(f"Message {message} inserting to db")
                        try:
                            await cursor.execute(sql_template, (message, ))
                        except Exception as e:
                            if logger:
                                logger.error(
                                    f"DB Error: {e}, pushing message to dead letter queue!"
                                )
                            # Must be awaited, otherwise the push never runs.
                            await _push_to_dead_letter_queue(message, channel)
                        # Ack after a successful insert or a successful
                        # dead-letter push, so the message is not redelivered.
                        # If the push raised, we never get here.
                        await m.ack()

    try:
        async with connection_pool, channel_pool:
            if logger:
                logger.info("Consumers started")
            await asyncio.gather(
                *(_consume() for _ in range(consumer_pool_size)))
    finally:
        db_pool.close()
        await db_pool.wait_closed()
示例#15
0
async def connect_to_rabbit():
    """Create the shared RabbitMQ pools and attach them to ``rmq``.

    Up to 2 connections (``get_connection_rmq``) and up to 10 channels
    (``get_channel``) are pooled.
    """
    max_connections, max_channels = 2, 10
    rmq.connections_pool = Pool(get_connection_rmq, max_size=max_connections)
    rmq.channels_pool = Pool(get_channel, max_size=max_channels)
示例#16
0
async def start():
    """Create the connection pool (5) and channel pool (10) and register them in ``conns``."""
    conns["connection_pool"] = Pool(get_connection, max_size=5)
    conns["channel_pool"] = Pool(get_channel, max_size=10)