Esempio n. 1
0
    async def call(self,
                   method_name,
                   kwargs: dict = None,
                   *,
                   expiration: int = None,
                   priority: int = 5,
                   delivery_mode: DeliveryMode = RPC.DELIVERY_MODE):
        """Publish a traced RPC call for ``method_name`` and await its reply.

        The whole round trip runs inside a tracer span named after the method
        (service ``'rabbitmq'``). The current trace context is injected into
        the outgoing message headers via ``self.DDTRACE_PROPAGATOR``.

        :param method_name: Span name and routing key of the call message.
        :param kwargs: Call arguments; a falsy value is serialized as ``{}``.
        :param expiration: Assigned to ``message.expiration`` unless ``None``.
        :param priority: Message priority.
        :param delivery_mode: Delivery mode of the call message.
        :return: The value the future from ``self.create_future()`` resolves to.
        """
        with tracer.trace(method_name, service='rabbitmq'):
            pending = self.create_future()

            # 'From' names the queue the reply is expected on; the propagator
            # adds its own entries to the same mapping.
            trace_headers = {'From': self.result_queue.name}
            span_context = current_trace_context()
            self.DDTRACE_PROPAGATOR.inject(span_context, trace_headers)

            request = Message(
                body=self.serialize(kwargs or {}),
                type=RPCMessageTypes.call.value,
                timestamp=time.time(),
                priority=priority,
                correlation_id=id(pending),
                delivery_mode=delivery_mode,
                reply_to=self.result_queue.name,
                headers=trace_headers,
            )

            if expiration is not None:
                request.expiration = expiration

            exchange = self.channel.default_exchange
            await exchange.publish(request,
                                   routing_key=method_name,
                                   mandatory=True)
            reply = await pending
            return reply
Esempio n. 2
0
    def on_callback(self, message: IncomingMessage):
        """Generator-based handler for an incoming call message.

        The handler named by the ``FuncName`` header is executed with the
        deserialized body. Its serialized outcome (type ``'result'``) or the
        serialized exception (type ``'error'``) is published to
        ``message.reply_to``, after which the incoming message is acked.
        Messages naming an unregistered function are returned from without
        being acked.
        """
        target = message.headers['FuncName']
        if target not in self.routes:
            return

        arguments = self.deserialize(message.body)
        handler = self.routes[target]

        try:
            outcome = yield from self._execute(handler, arguments)
            body, kind = self.serialize(outcome), 'result'
        except Exception as exc:
            # Any failure of the handler (or of serializing its result) is
            # reported to the caller as an 'error' reply.
            body, kind = self.serialize(exc), 'error'

        reply = Message(
            body,
            delivery_mode=message.delivery_mode,
            correlation_id=message.correlation_id,
            timestamp=time.time(),
            type=kind,
        )

        exchange = self.channel.default_exchange
        yield from exchange.publish(reply, message.reply_to, mandatory=False)

        message.ack()
Esempio n. 3
0
    async def do_create_task(self,
                             data: dict,
                             routing_key: str = None,
                             queue: str = None):
        """Create a new task for the worker and publish it.

        If the first publish attempt fails, the error is logged, the
        connection is re-established via ``self.reconnect()`` and, after a
        two second pause, the publish is retried exactly once. A failure of
        the retry propagates to the caller.

        :param data: Task payload; a falsy value is serialized as ``{}``.
        :param routing_key: Routing key forwarded to ``self.do_publish``.
        :param queue: Queue name forwarded to ``self.do_publish``.
        :raises ValueError: if neither ``routing_key`` nor ``queue`` is given.
        """
        target = routing_key or queue
        # Validate before logging so an invalid call does not log a bogus
        # "Creating TASK to None" line.
        if not target:
            raise ValueError("args routing_key or queue required")
        logging.info("Creating TASK to %s", target)

        message = Message(
            body=self.serialize(data or {}),
            content_type=self.CONTENT_TYPE,
            delivery_mode=self.DELIVERY_MODE,
        )

        try:
            await self.do_publish(routing_key=routing_key,
                                  message=message,
                                  queue=queue)
        except Exception as e:
            # Keep the traceback: the retry below may mask the root cause.
            logging.exception(e)
            await self.reconnect()
            # The ``loop`` argument of asyncio.sleep was removed in
            # Python 3.10; the running loop is used implicitly.
            await asyncio.sleep(2)
            await self.do_publish(routing_key=routing_key,
                                  message=message,
                                  queue=queue)
Esempio n. 4
0
    def call(self,
             func_name,
             kwargs: dict = None,
             *,
             expiration: int = None,
             priority: int = 128,
             delivery_mode: DeliveryMode = DeliveryMode.NOT_PERSISTENT):
        """Generator-based RPC call: publish the request, then wait for the reply.

        The request is routed to ``self.queue_name``; the function to run is
        carried in the ``FuncName`` header and the reply queue in both
        ``reply_to`` and the ``From`` header.

        :param func_name: Name of the remote function.
        :param kwargs: Call arguments; a falsy value is serialized as ``{}``.
        :param expiration: Passed to the message as is (may be ``None``).
        :param priority: Message priority.
        :param delivery_mode: Delivery mode of the call message.
        :return: The value the future from ``self._create_future()`` resolves to.
        """
        pending = self._create_future()

        body = self.serialize(kwargs or {})
        sent_at = time.time()
        reply_queue = self.result_queue.name

        request = Message(
            body=body,
            type='call',
            timestamp=sent_at,
            expiration=expiration,
            priority=priority,
            correlation_id=id(pending),
            delivery_mode=delivery_mode,
            reply_to=reply_queue,
            headers={'From': reply_queue, 'FuncName': func_name},
        )

        exchange = self.channel.default_exchange
        yield from exchange.publish(request,
                                    routing_key=self.queue_name,
                                    mandatory=True)

        reply = yield from pending
        return reply
Esempio n. 5
0
    def on_call_message(self, method_name: str, message: IncomingMessage):
        """Generator-based handler that runs ``method_name`` and sends the reply.

        An unregistered method is logged as a warning and the message is left
        unacknowledged. Otherwise the handler's serialized result (type
        ``'result'``) or the serialized exception (type ``'error'``) is
        published to ``message.reply_to`` and the incoming message is acked.
        """
        if method_name not in self.routes:
            log.warning("Method %r not registered in %r", method_name, self)
            return

        arguments = self.deserialize(message.body)
        handler = self.routes[method_name]

        try:
            outcome = yield from self.execute(handler, arguments)
            body, kind = self.serialize(outcome), 'result'
        except Exception as exc:
            # Failures while executing or serializing become an error reply.
            body, kind = self.serialize_exception(exc), 'error'

        reply = Message(
            body,
            delivery_mode=message.delivery_mode,
            correlation_id=message.correlation_id,
            timestamp=time.time(),
            type=kind,
        )

        exchange = self.channel.default_exchange
        yield from exchange.publish(reply, message.reply_to, mandatory=False)

        message.ack()
Esempio n. 6
0
    async def call(
        self,
        method_name,
        kwargs: Optional[Dict[Hashable, Any]] = None,
        *,
        expiration: Optional[int] = None,
        priority: int = 5,
        delivery_mode: DeliveryMode = DELIVERY_MODE
    ):
        """ Call a remote method and wait for its result.

        :param method_name: Name of the method; used as the routing key
        :param kwargs: Method kwargs; a falsy value is serialized as ``{}``
        :param expiration:
            If not `None`, messages staying in the queue longer
            will be returned and :class:`asyncio.TimeoutError` will be raised.
        :param priority: Message priority
        :param delivery_mode: Call message delivery mode
        :raises asyncio.TimeoutError: when message expired
        :raises CancelledError: when called :func:`RPC.cancel`
        :raises RuntimeError: internal error
        """

        pending = self.create_future()

        body = self.serialize(kwargs or {})
        sent_at = time.time()
        reply_queue = self.result_queue.name

        request = Message(
            body=body,
            type=RPCMessageTypes.call.value,
            timestamp=sent_at,
            priority=priority,
            correlation_id=id(pending),
            delivery_mode=delivery_mode,
            reply_to=reply_queue,
            headers={"From": reply_queue},
        )

        # Expiration is only set when requested, leaving the message's own
        # default untouched otherwise.
        if expiration is not None:
            request.expiration = expiration

        log.debug("Publishing calls for %s(%r)", method_name, kwargs)
        exchange = self.channel.default_exchange
        await exchange.publish(request, routing_key=method_name, mandatory=True)

        log.debug("Waiting RPC result for %s(%r)", method_name, kwargs)
        reply = await pending
        return reply
async def test_robust_duplicate_queue(
    connection: aio_pika.RobustConnection,
    declare_exchange: Callable,
    declare_queue: Callable,
    loop: asyncio.AbstractEventLoop,
    proxy: TCPProxy,
    create_task: Callable,
):
    """Two channels consuming the same queue both survive a reconnect.

    Five messages are published before the proxy drops every client and
    five more after the connection reports a reconnect; all ten must reach
    the shared list filled by the two readers.
    """
    queue_name = "test"

    channel1 = await connection.channel()
    channel2 = await connection.channel()

    reconnect_event = asyncio.Event()
    connection.reconnect_callbacks.add(
        lambda *_: reconnect_event.set(), weak=False
    )

    shared = []
    queue1 = await declare_queue(queue_name, channel=channel1, cleanup=False)
    queue2 = await declare_queue(queue_name, channel=channel2, cleanup=False)

    async def reader(queue):
        # ``shared`` is only mutated, never rebound, so no ``nonlocal``.
        async with queue.iterator() as q:
            async for message in q:
                shared.append(message)
                await message.ack()

    create_task(reader(queue1))
    create_task(reader(queue2))

    for _ in range(5):
        await channel2.default_exchange.publish(
            aio_pika.Message(b""), queue_name,
        )

    logging.info("Disconnect all clients")
    await proxy.disconnect_all()

    await reconnect_event.wait()

    logging.info("Waiting connections")
    # asyncio.wait() no longer accepts bare coroutines (deprecated in 3.8,
    # rejected since 3.11) and would hide readiness failures; gather()
    # schedules them and re-raises the first error.
    await asyncio.gather(
        channel1._connection.ready(),
        channel2._connection.ready(),
    )

    for _ in range(5):
        await channel2.default_exchange.publish(
            aio_pika.Message(b""), queue_name,
        )

    # Poll until every published message has been consumed.
    while len(shared) < 10:
        await asyncio.sleep(0.1)

    assert len(shared) == 10
Esempio n. 8
0
    async def test_robust_reconnect(self):
        """A queue reader keeps receiving messages after a forced reconnect.

        Five messages are published before the proxy disconnects and five
        after both channels' connections report ready again; the reader must
        have collected all ten.
        """
        channel1 = await self.create_channel()
        channel2 = await self.create_channel()

        shared = []
        queue = await channel1.declare_queue()

        async def reader():
            # ``shared`` is only mutated, never rebound, so no ``nonlocal``.
            async with queue.iterator() as q:
                async for message in q:
                    shared.append(message)
                    await message.ack()

        reader_task = self.loop.create_task(reader())
        self.addCleanup(reader_task.cancel)

        for _ in range(5):
            await channel2.default_exchange.publish(
                Message(b''),
                queue.name,
            )

        logging.info("Disconnect all clients")
        await self.proxy.disconnect()

        logging.info("Waiting for reconnect")
        await asyncio.sleep(5)

        logging.info("Waiting connections")
        # asyncio.wait() no longer accepts bare coroutines (deprecated in
        # 3.8, rejected since 3.11) and would hide readiness failures;
        # gather() schedules them and re-raises the first error.
        await asyncio.gather(
            channel1._connection.ready(),
            channel2._connection.ready(),
        )

        for _ in range(5):
            await channel2.default_exchange.publish(
                Message(b''),
                queue.name,
            )

        # Poll until every published message has been consumed.
        while len(shared) < 10:
            await asyncio.sleep(0.1)

        assert len(shared) == 10
Esempio n. 9
0
    def call(self, func_name: str, kwargs=None, priority=128):
        """Publish a call for ``func_name`` without waiting for any reply.

        The function name travels in the ``FuncName`` header and the message
        is routed to ``self.queue_name`` on the default exchange.

        :param func_name: Name of the function to invoke remotely.
        :param kwargs: Call arguments; a falsy value is serialized as ``{}``.
        :param priority: Message priority.
        """
        request = Message(
            body=self.serialize(kwargs or {}),
            content_type=self.CONTENT_TYPE,
            delivery_mode=self.DELIVERY_MODE,
            priority=priority,
            headers={'FuncName': func_name},
        )

        exchange = self.channel.default_exchange
        yield from exchange.publish(request, self.queue_name, mandatory=True)
Esempio n. 10
0
    def create_task(self, channel_name: str, kwargs=None, **message_kwargs):
        """Publish a task message for a worker.

        :param channel_name: Routing key used on the default exchange.
        :param kwargs: Task arguments; a falsy value is serialized as ``{}``.
        :param message_kwargs: Extra keyword arguments forwarded to ``Message``.
        """
        task = Message(
            body=self.serialize(kwargs or {}),
            content_type=self.CONTENT_TYPE,
            delivery_mode=self.DELIVERY_MODE,
            **message_kwargs
        )

        exchange = self.channel.default_exchange
        yield from exchange.publish(task, channel_name, mandatory=True)
Esempio n. 11
0
async def send_message_to_queue(queue_connection: RobustConnection,
                                queue_name: str, message: bytes):
    """Publish ``message`` persistently to the durable queue ``queue_name``.

    A dedicated channel is created for the publish and is closed again
    whether or not initializing, declaring or publishing succeeds.
    """
    publisher = queue_connection.channel()
    try:
        await publisher.initialize()
        # Declare the queue first so the routing key has a destination.
        await publisher.declare_queue(queue_name, durable=True)
        outgoing = Message(message, delivery_mode=DeliveryMode.PERSISTENT)
        await publisher.default_exchange.publish(outgoing, queue_name)
    finally:
        await publisher.close()
Esempio n. 12
0
 async def rpc_call_async(
     self,
     method: str,
     kwargs: dict,
     expiration: int = 10,
     priority: RPCCallPriority = RPCCallPriority.MEDIUM,
 ) -> bool:
     """Publish an RPC call message without waiting for its result.

     :param method: Routing key of the call message.
     :param kwargs: Call arguments; a falsy value is serialized as ``{}``.
     :param expiration: Assigned to ``message.expiration`` unless ``None``.
     :param priority: Priority enum member; its ``value`` is sent.
     :return: ``True`` when the publish response is a ``Basic.Ack``.
     """
     request = Message(
         body=self.rpc.serialize(kwargs or {}),
         type=RPCMessageTypes.call.value,
         timestamp=time(),
         priority=priority.value,
         delivery_mode=self.rpc.DELIVERY_MODE,
     )
     if expiration is not None:
         request.expiration = expiration

     exchange = self.channel.default_exchange
     confirmation = await exchange.publish(
         request, routing_key=method, mandatory=True,
     )
     return isinstance(confirmation, Basic.Ack)
Esempio n. 13
0
    def call(self, method_name, kwargs: dict=None, *, expiration: int = None,
             priority: int = 128, delivery_mode: DeliveryMode = DeliveryMode.NOT_PERSISTENT):

        """ Call a remote method and wait for its result (generator-based).

        :param method_name: Name of the method; used as the routing key
        :param kwargs: Method kwargs; a falsy value is serialized as ``{}``
        :param expiration: If not `None`, messages staying in the queue longer
                           will be returned and :class:`asyncio.TimeoutError` will be raised.
        :param priority: Message priority
        :param delivery_mode: Call message delivery mode
        :raises asyncio.TimeoutError: when message expired
        :raises CancelledError: when called :func:`RPC.cancel`
        :raises RuntimeError: internal error
        """

        pending = self.create_future()

        body = self.serialize(kwargs or {})
        sent_at = time.time()
        reply_queue = self.result_queue.name

        request = Message(
            body=body,
            type='call',
            timestamp=sent_at,
            priority=priority,
            correlation_id=id(pending),
            delivery_mode=delivery_mode,
            reply_to=reply_queue,
            headers={'From': reply_queue},
        )

        # Expiration is only set when requested.
        if expiration is not None:
            request.expiration = expiration

        exchange = self.channel.default_exchange
        yield from exchange.publish(
            request, routing_key=method_name, mandatory=True
        )

        reply = yield from pending
        return reply
Esempio n. 14
0
    async def on_call_message(self, method_name: str,
                              message: IncomingMessage):
        """Run the registered handler for ``method_name`` and send the reply.

        * Unregistered method: a warning is logged and the message is left
          unacknowledged.
        * Any exception while deserializing, executing or serializing is
          turned into an error reply.
        * Without ``reply_to`` the outcome is dropped and the message acked.
        * If publishing the reply fails, the message is rejected without
          requeue; otherwise it is acked.
        """
        if method_name not in self.routes:
            log.warning("Method %r not registered in %r", method_name, self)
            return

        try:
            arguments = self.deserialize(message.body)
            handler = self.routes[method_name]

            outcome = await self.execute(handler, arguments)
            body = self.serialize(outcome)
            kind = RPCMessageTypes.result.value
        except Exception as exc:
            body = self.serialize_exception(exc)
            kind = RPCMessageTypes.error.value

        if not message.reply_to:
            log.info(
                'RPC message without "reply_to" header %r call result '
                "will be lost",
                message,
            )
            await message.ack()
            return

        reply = Message(
            body,
            content_type=self.CONTENT_TYPE,
            correlation_id=message.correlation_id,
            delivery_mode=message.delivery_mode,
            timestamp=time.time(),
            type=kind,
        )

        exchange = self.channel.default_exchange
        try:
            await exchange.publish(reply, message.reply_to, mandatory=False)
        except Exception:
            log.exception("Failed to send reply %r", reply)
            await message.reject(requeue=False)
            return

        # Result and error replies are acknowledged alike once delivered.
        await message.ack()
Esempio n. 15
0
    def get_response_message(self,
                             payload=None,
                             headers: dict = None,
                             content_type: str = None,
                             content_encoding: str = None,
                             delivery_mode: DeliveryMode = None,
                             priority: int = None,
                             correlation_id=None,
                             reply_to: str = None,
                             expiration: DateType = None,
                             message_id: str = None,
                             timestamp: DateType = None,
                             type: str = None,
                             user_id: str = None,
                             app_id: str = None) -> Message:
        """Build a response ``Message`` based on this object's properties.

        Every argument overrides the attribute of the same name on ``self``.
        ``headers`` are merged on top of ``self.headers`` rather than
        replacing them. The body and the default content type come from
        ``serialize(payload or self.payload)``.

        NOTE(review): overrides are combined with ``or``, so falsy explicit
        values (for example ``priority=0`` or an empty payload) fall back to
        the attributes of ``self`` - confirm this is intended.

        :return: A new ``Message`` with the merged properties.
        """
        # Serialization errors are not handled here; they propagate.
        body, _content_type = serialize(payload or self.payload)

        if headers is None:
            _headers = self.headers
        elif self.headers is None:
            _headers = headers
        else:
            # Merge into a copy: updating self.headers in place would leak
            # per-response headers into this object and later responses.
            _headers = dict(self.headers)
            _headers.update(headers)

        return Message(body,
                       content_type=content_type or _content_type
                       or self.content_type,
                       content_encoding=content_encoding
                       or self.content_encoding,
                       headers=_headers,
                       delivery_mode=delivery_mode or self.delivery_mode,
                       priority=priority or self.priority,
                       correlation_id=correlation_id or self.correlation_id,
                       reply_to=reply_to or self.reply_to,
                       expiration=expiration or self.expiration,
                       message_id=message_id or self.message_id,
                       timestamp=timestamp or self.timestamp,
                       type=type or self.type,
                       user_id=user_id or self.user_id,
                       app_id=app_id or self.app_id)
async def test_context_process_abrupt_channel_close(
    connection: aio_pika.RobustConnection,
    declare_exchange: Callable,
    declare_queue: Callable,
):
    """``process()`` must raise when the underlying channel was closed.

    Regression test for https://github.com/mosquito/aio-pika/issues/302
    """
    queue_name = get_random_name("test_connection")
    routing_key = get_random_name("rounting_key")

    robust_channel = await connection.channel()
    exchange = await declare_exchange(
        "direct", auto_delete=True, channel=robust_channel,
    )
    queue = await declare_queue(
        queue_name, auto_delete=True, channel=robust_channel,
    )

    await queue.bind(exchange, routing_key)
    payload = bytes(shortuuid.uuid(), "utf-8")

    outgoing = Message(
        payload, content_type="text/plain", headers={"foo": "bar"},
    )
    await exchange.publish(outgoing, routing_key)

    received = await queue.get(timeout=5)
    # Close the low-level aiormq channel to emulate an abrupt
    # connection/channel close.
    await robust_channel.channel.close()
    with pytest.raises(aiormq.exceptions.ChannelInvalidStateError):
        async with received.process():
            # Emulate some activity on the closed channel.
            await robust_channel.channel.basic_publish(
                "dummy", exchange="", routing_key="non_existent",
            )

    # Emulate the connection/channel restoration done by connect_robust.
    await robust_channel.reopen()

    # Clean up: take the message from the queue again, process it
    # successfully, then unbind.
    leftover = await queue.get(timeout=5)
    async with leftover.process():
        pass
    await queue.unbind(exchange, routing_key)
Esempio n. 17
0
    async def request(
        self,
        data: Any,
        *,
        service_name: str = None,
        expiration: int = None,
        delivery_mode: DeliveryMode = DeliveryMode.NOT_PERSISTENT,
    ) -> Any:
        """ Send a message to a service and wait for a response.

        :param data: The message data to transfer.

        :param service_name: Routing key of the target service. Falls back
          to ``self.service_name`` when not given (or falsy).

        :param expiration: An optional value representing the number of
          seconds that a message can remain in a queue before being returned
          and a timeout exception (:class:`asyncio.TimeoutError`) is raised.

        :param delivery_mode: Request message delivery mode. Default is
          not-persistent.

        :returns: The result of the awaited response future (not the future
          itself).

        :raises asyncio.TimeoutError: When message expires before being handled.

        :raises CancelledError: when called :func:`RPC.cancel`

        :raises Exception: internal error
        """
        service_name = service_name if service_name else self.service_name

        correlation_id, future = self.create_future()

        headers = {}  # type: Dict[str, str]

        # An exception may be raised here if the message can not be serialized.
        payload, content_type, content_encoding = utils.encode_payload(
            data,
            content_type=self.serialization,
            compression=self.compression,
            headers=headers,
        )

        assert self.response_queue is not None

        # Add a 'From' entry to message headers which will be used to route an
        # expired message to the dead letter exchange queue.
        headers["From"] = self.response_queue.name

        message = Message(
            body=payload,
            content_type=content_type,
            content_encoding=content_encoding,
            timestamp=time.time(),
            correlation_id=correlation_id,
            delivery_mode=delivery_mode,
            reply_to=self.response_queue.name,
            headers=headers,
        )

        if expiration is not None:
            message.expiration = expiration

        # Lazy %-formatting: the text is only built when DEBUG is enabled.
        logger.debug(
            "Sending request to %s with correlation_id: %s",
            service_name,
            correlation_id,
        )

        assert self.exchange is not None
        await self.exchange.publish(
            message,
            routing_key=service_name,
            mandatory=True,  # report error if no queues are actively consuming
        )

        logger.debug("Waiting for response from %s", service_name)
        return await future
Esempio n. 18
0
 async def publish(self, key, data):
     """Publish ``data`` as JSON to the exchange registered for ``key``.

     The payload is ``data`` plus a ``'key'`` entry holding ``key``. It is
     built as a copy so the caller's mapping is not modified.

     :param key: Lookup key of the target exchange; also embedded in the body.
     :param data: Mapping to send; must be JSON-serializable.
     :raises KeyError: if no exchange is registered for ``key``.
     """
     payload = {**data, 'key': key}
     exchange = self.__key_ch[key]  # type: aio_pika.Exchange
     await exchange.publish(Message(json.dumps(payload).encode()),
                            routing_key="")
 def get_msg(self):
     """Return a ``Message`` whose body is this object's JSON.

     Only the ``id`` and ``value`` fields are included and aliases are used
     (``by_alias=True``).
     """
     encoded = self.json(include={"id", "value"}, by_alias=True).encode()
     return Message(body=encoded)
async def test_robust_reconnect(
    create_connection, proxy: TCPProxy, loop, add_cleanup: Callable
):
    """A reader on a robust connection resumes after the proxy drops clients.

    Messages ``0..4`` are published before the disconnect and ``5..9`` after
    both connections report ready again; the reader must collect all ten.
    """
    conn1 = await create_connection()
    conn2 = await create_connection()

    assert isinstance(conn1, aio_pika.RobustConnection)
    assert isinstance(conn2, aio_pika.RobustConnection)

    async with conn1, conn2:

        channel1 = await conn1.channel()
        channel2 = await conn2.channel()

        assert isinstance(channel1, aio_pika.RobustChannel)
        assert isinstance(channel2, aio_pika.RobustChannel)

        async with channel1, channel2:
            shared = []

            # Declaring temporary queue
            queue = await channel1.declare_queue()

            async def reader(queue_name):
                # Re-declare passively: only look the existing queue up.
                # ``shared`` is only mutated, so no ``nonlocal`` is needed.
                reader_queue = await channel1.declare_queue(
                    name=queue_name, passive=True,
                )
                async with reader_queue.iterator() as q:
                    async for message in q:
                        shared.append(message)
                        await message.ack()

            reader_task = loop.create_task(reader(queue.name))

            # Always stop the reader, even when an assertion or a
            # reconnect step below fails.
            try:
                for i in range(5):
                    await channel2.default_exchange.publish(
                        Message(str(i).encode()), queue.name,
                    )

                logging.info("Disconnect all clients")
                with proxy.slowdown(1, 1):
                    task = proxy.disconnect_all()

                    # Opening channels while the proxy is dropping
                    # clients is expected to fail.
                    with pytest.raises((
                        ConnectionResetError, ConnectionError,
                        aiormq.exceptions.ChannelInvalidStateError
                    )):
                        await asyncio.gather(
                            conn1.channel(), conn2.channel(),
                        )

                    await task

                logging.info("Waiting reconnect")
                await asyncio.sleep(conn1.reconnect_interval * 2)

                logging.info("Waiting connections")
                await asyncio.wait_for(
                    asyncio.gather(conn1.ready(), conn2.ready()),
                    timeout=20,
                )

                for i in range(5, 10):
                    await channel2.default_exchange.publish(
                        Message(str(i).encode()), queue.name,
                    )

                # Poll until every published message has been consumed.
                while len(shared) < 10:
                    await asyncio.sleep(0.1)

                assert len(shared) == 10
            finally:
                reader_task.cancel()
                await asyncio.gather(reader_task, return_exceptions=True)