コード例 #1
0
    def send_heartbeat(self):
        """Write a single Heartbeat frame on this channel.

        Nothing is written once the owning connection reports that it is no
        longer open.

        :return:
        """
        if self._connection.is_open:
            self._write_frame(Heartbeat())
コード例 #2
0
class Connection(Base):
    """AMQP client connection over an asyncio stream reader/writer pair.

    Virtual host, TLS files, heartbeat settings and the connection name are
    read from the URL. :meth:`connect` performs the protocol handshake and
    then starts a background reader task (which dispatches incoming frames
    to channels) and a heartbeat task.
    """

    FRAME_BUFFER = 10
    # Interval between sending heartbeats based on the heartbeat(timeout)
    HEARTBEAT_INTERVAL_MULTIPLIER = 0.5
    # Close the connection when nothing has been received from the server
    # for three heartbeat(timeout) periods
    HEARTBEAT_GRACE_MULTIPLIER = 3
    # Heartbeat frame marshalled once for channel 0 and reused on every send
    _HEARTBEAT = pamqp.frame.marshal(Heartbeat(), 0)

    @staticmethod
    def _parse_ca_data(data) -> typing.Optional[bytes]:
        """Base64-decode the ``cadata`` URL parameter; None when empty."""
        return b64decode(data) if data else None

    def __init__(self,
                 url: URLorStr,
                 *,
                 parent=None,
                 loop: asyncio.AbstractEventLoop = None):
        """Parse connection settings from ``url``; no I/O happens here.

        :param url: AMQP URL; the path selects the virtual host and the
            query parameters configure TLS, heartbeats and the name.
        :param parent: passed through to ``Base``.
        :param loop: event loop; defaults to ``asyncio.get_event_loop()``.
        """
        super().__init__(loop=loop or asyncio.get_event_loop(), parent=parent)

        self.url = URL(url)

        # An empty path or a bare "/" selects the default virtual host "/"
        if self.url.path == "/" or not self.url.path:
            self.vhost = "/"
        else:
            self.vhost = self.url.path[1:]

        self._reader_task = None  # type: asyncio.Task
        self.reader = None  # type: asyncio.StreamReader
        self.writer = None  # type: asyncio.StreamWriter
        self.ssl_certs = SSLCerts(
            cafile=self.url.query.get("cafile"),
            capath=self.url.query.get("capath"),
            cadata=self._parse_ca_data(self.url.query.get("cadata")),
            key=self.url.query.get("keyfile"),
            cert=self.url.query.get("certfile"),
            verify=self.url.query.get("no_verify_ssl", "0") == "0",
        )

        self.started = False
        # Serialises frame reads (exposed through the ``lock`` property)
        self.__lock = asyncio.Lock()
        # Serialises ``writer.drain()`` calls
        self.__drain_lock = asyncio.Lock()

        # Channel number -> Channel; None marks a closed or failed channel
        self.channels = {}  # type: typing.Dict[int, typing.Optional[Channel]]

        self.server_properties = None  # type: spec.Connection.OpenOk
        self.connection_tune = None  # type: spec.Connection.TuneOk

        self.last_channel = 1

        self.heartbeat_monitoring = parse_bool(
            self.url.query.get("heartbeat_monitoring", "1"))
        self.heartbeat_timeout = parse_int(self.url.query.get(
            "heartbeat", "0"))
        self.heartbeat_last_received = 0
        self.last_channel_lock = asyncio.Lock()
        self.connected = asyncio.Event()
        self.connection_name = self.url.query.get("name")

    @property
    def lock(self):
        """Lock guarding frame reads; raises RuntimeError once closed."""
        if self.is_closed:
            raise RuntimeError("%r closed" % self)

        return self.__lock

    async def drain(self):
        """Wait for the writer's buffer to flush, one drain at a time.

        :raises RuntimeError: there is no writer (not connected / closed).
        """
        async with self.__drain_lock:
            if not self.writer:
                raise RuntimeError("Writer is %r" % self.writer)
            return await self.writer.drain()

    @property
    def is_opened(self):
        return self.writer is not None and not self.is_closed

    def __str__(self):
        return str(censor_url(self.url))

    def _get_ssl_context(self):
        """Build the client-side TLS context from the URL's TLS parameters.

        The context is always created for ``Purpose.SERVER_AUTH``: this side
        is the client and must verify the broker's certificate. A
        ``keyfile``/``certfile`` pair adds a client certificate, and any
        ``no_verify_ssl`` value other than "0" disables verification.
        """
        context = ssl.create_default_context(
            ssl.Purpose.SERVER_AUTH,
            capath=self.ssl_certs.capath,
            cafile=self.ssl_certs.cafile,
            cadata=self.ssl_certs.cadata,
        )

        if self.ssl_certs.key:
            context.load_cert_chain(self.ssl_certs.cert, self.ssl_certs.key)

        if not self.ssl_certs.verify:
            # check_hostname must be switched off before verify_mode can be
            # lowered to CERT_NONE
            context.check_hostname = False
            context.verify_mode = ssl.CERT_NONE

        return context

    def _client_properties(self, **kwargs):
        """Build the ``client_properties`` table sent in Connection.StartOk.

        The defaults are overlaid first with the parsed connection name and
        then with the mapping stored under the ``client_properties`` keyword
        (the nested key, not the keyword arguments themselves).
        """
        properties = {
            "platform": PLATFORM,
            "version": __version__,
            "product": PRODUCT,
            "capabilities": {
                "authentication_failure_close": True,
                "basic.nack": True,
                "connection.blocked": False,
                "consumer_cancel_notify": True,
                "publisher_confirms": True,
            },
            "information": "See https://github.com/mosquito/aiormq/",
        }

        properties.update(parse_connection_name(self.connection_name))
        properties.update(kwargs.get("client_properties", {}))
        return properties

    @staticmethod
    def _credentials_class(start_frame: spec.Connection.Start):
        """Return the first server-offered mechanism known to AuthMechanism.

        :raises exc.AuthenticationError: none of the offered mechanisms is
            supported.
        """
        for mechanism in start_frame.mechanisms.decode().split():
            with suppress(KeyError):
                return AuthMechanism[mechanism]

        raise exc.AuthenticationError(start_frame.mechanisms,
                                      [m.name for m in AuthMechanism])

    async def __rpc(self, request: spec.Frame, wait_response=True):
        """Send a channel-0 method frame and optionally wait for the reply.

        :returns: the reply frame, or None when ``wait_response`` is False.
        :raises exc.ProbableAuthenticationError: the server closed the
            connection with reply code 403 (the connection is closed too).
        :raises exc.ConnectionClosed: the server closed the connection with
            another reply code (the connection is closed too).
        :raises spec.AMQPInternalError: a synchronous request received a
            reply that is not among its valid responses.
        """
        self.writer.write(pamqp.frame.marshal(request, 0))

        if not wait_response:
            return

        _, _, frame = await self.__receive_frame()

        # Handle Connection.Close first so a server-initiated close (e.g. a
        # 403 after bad credentials) is reported as such even when Close is
        # not among the request's valid_responses.
        if isinstance(frame, spec.Connection.Close):
            if frame.reply_code == 403:
                err = exc.ProbableAuthenticationError(frame.reply_text)
            else:
                err = exc.ConnectionClosed(frame.reply_code, frame.reply_text)

            await self.close(err)

            raise err

        if request.synchronous and frame.name not in request.valid_responses:
            raise spec.AMQPInternalError(frame, dict(frame))

        return frame

    @task
    async def connect(self, client_properties: dict = None):
        """Open the socket and perform the connection handshake.

        Sends the protocol header, answers Connection.Start with StartOk,
        answers Tune with TuneOk (using the URL's ``heartbeat`` value when it
        is positive) and sends Connection.Open. Then starts the reader and
        heartbeat tasks and schedules ``connected`` to be set.

        :raises RuntimeError: already connected.
        :raises ConnectionError: the TCP/TLS connection could not be opened.
        :raises exc.IncompatibleProtocolError: the stream ended while waiting
            for the first server frame.
        """
        if self.writer is not None:
            raise RuntimeError("Already connected")

        ssl_context = None

        if self.url.scheme == "amqps":
            # Loading certificates touches the filesystem; keep it off the
            # event loop
            ssl_context = await self.loop.run_in_executor(
                None, self._get_ssl_context)

        try:
            self.reader, self.writer = await asyncio.open_connection(
                self.url.host, self.url.port, ssl=ssl_context)
        except OSError as e:
            raise ConnectionError(*e.args) from e

        try:
            protocol_header = ProtocolHeader()
            self.writer.write(protocol_header.marshal())

            res = await self.__receive_frame()
            _, _, frame = res  # type: spec.Connection.Start
            self.heartbeat_last_received = self.loop.time()
        except EOFError as e:
            raise exc.IncompatibleProtocolError(*e.args) from e

        credentials = self._credentials_class(frame)

        self.server_properties = frame.server_properties

        # noinspection PyTypeChecker
        self.connection_tune = await self.__rpc(
            spec.Connection.StartOk(
                client_properties=self._client_properties(
                    **(client_properties or {})),
                mechanism=credentials.name,
                response=credentials.value(self).marshal(),
            ))  # type: spec.Connection.Tune

        if self.heartbeat_timeout > 0:
            self.connection_tune.heartbeat = self.heartbeat_timeout

        await self.__rpc(
            spec.Connection.TuneOk(
                channel_max=self.connection_tune.channel_max,
                frame_max=self.connection_tune.frame_max,
                heartbeat=self.connection_tune.heartbeat,
            ),
            wait_response=False,
        )

        await self.__rpc(spec.Connection.Open(virtual_host=self.vhost))

        # noinspection PyAsyncCall
        self._reader_task = self.create_task(self.__reader())

        # noinspection PyAsyncCall
        heartbeat_task = self.create_task(self.__heartbeat_task())
        heartbeat_task.add_done_callback(self._on_heartbeat_done)
        self.loop.call_soon(self.connected.set)

        return True

    def _on_heartbeat_done(self, future):
        """Close the connection when the heartbeat task failed."""
        if not future.cancelled() and future.exception():
            self.create_task(
                self.close(ConnectionError("heartbeat task was failed.")))

    async def __heartbeat_task(self):
        """Send heartbeats and, optionally, watch for a silent server.

        Does nothing when the negotiated heartbeat is 0. Otherwise a
        heartbeat is written every ``heartbeat * HEARTBEAT_INTERVAL_MULTIPLIER``
        seconds while a writer exists; with monitoring enabled the connection
        is closed when nothing was received for
        ``heartbeat * HEARTBEAT_GRACE_MULTIPLIER`` seconds.
        """
        if not self.connection_tune.heartbeat:
            return

        heartbeat_interval = (self.connection_tune.heartbeat *
                              self.HEARTBEAT_INTERVAL_MULTIPLIER)
        heartbeat_grace_timeout = (self.connection_tune.heartbeat *
                                   self.HEARTBEAT_GRACE_MULTIPLIER)

        while self.writer:
            # Send heartbeat to server unconditionally
            self.writer.write(self._HEARTBEAT)

            await asyncio.sleep(heartbeat_interval)

            if not self.heartbeat_monitoring:
                continue

            # Check if the server sent us something
            # within the heartbeat grace period
            last_heartbeat = self.loop.time() - self.heartbeat_last_received

            if last_heartbeat <= heartbeat_grace_timeout:
                continue

            await self.close(
                ConnectionError(
                    "Server connection probably hang, last heartbeat "
                    "received %.3f seconds ago" % last_heartbeat))

            return

    async def __receive_frame(self) -> typing.Tuple[int, int, spec.Frame]:
        """Read exactly one frame from the stream and unmarshal it.

        :returns: the ``(weight, channel, frame)`` tuple from pamqp's
            unmarshal (as unpacked by callers).
        :raises spec.AMQPFrameError: the frame-type byte is 0x00.
        :raises spec.AMQPSyntaxError: the first frame starts with ``AMQP``
            (a protocol header instead of a frame).
        :raises ConnectionError: the reader has been detached (closed).
        """
        async with self.lock:
            if self.reader is None:
                raise ConnectionError

            frame_header = await self.reader.readexactly(1)

            # A NUL frame-type byte is treated as a framing error; the rest
            # of the stream (read until EOF) is attached for diagnostics.
            if frame_header == b"\x00":
                raise spec.AMQPFrameError(await self.reader.read())

            # The connection may have been closed while we were waiting
            if self.reader is None:
                raise ConnectionError

            frame_header += await self.reader.readexactly(6)

            if not self.started and frame_header.startswith(b"AMQP"):
                raise spec.AMQPSyntaxError
            else:
                self.started = True

            frame_type, _, frame_length = pamqp.frame.frame_parts(frame_header)

            # Payload plus the trailing frame-end byte
            frame_payload = await self.reader.readexactly(frame_length + 1)

        return pamqp.frame.unmarshal(frame_header + frame_payload)

    @staticmethod
    def __exception_by_code(frame: spec.Connection.Close):
        """Map a Connection.Close reply code to an exception instance."""
        if frame.reply_code == 501:
            return exc.ConnectionFrameError(frame.reply_text)
        elif frame.reply_code == 502:
            return exc.ConnectionSyntaxError(frame.reply_text)
        elif frame.reply_code == 503:
            return exc.ConnectionCommandInvalid(frame.reply_text)
        elif frame.reply_code == 504:
            return exc.ConnectionChannelError(frame.reply_text)
        elif frame.reply_code == 505:
            return exc.ConnectionUnexpectedFrame(frame.reply_text)
        elif frame.reply_code == 506:
            return exc.ConnectionResourceError(frame.reply_text)
        elif frame.reply_code == 530:
            return exc.ConnectionNotAllowed(frame.reply_text)
        elif frame.reply_code == 540:
            return exc.ConnectionNotImplemented(frame.reply_text)
        elif frame.reply_code == 541:
            return exc.ConnectionInternalError(frame.reply_text)
        else:
            return exc.ConnectionClosed(frame.reply_code, frame.reply_text)

    @task
    async def __reader(self):
        """Read frames until EOF and route them.

        Channel-0 frames are handled here: CloseOk stops the loop, Close
        closes the connection with the mapped exception, and heartbeats are
        skipped. Frames for other channels are put on that channel's
        ``frames`` queue, and a channel is marked closed (None) when a frame
        in CHANNEL_CLOSE_RESPONSES arrives. Read errors close the connection.
        """
        try:
            while not self.reader.at_eof():
                weight, channel, frame = await self.__receive_frame()

                # Any received frame counts as a sign of life for the
                # heartbeat monitor
                self.heartbeat_last_received = self.loop.time()

                if channel == 0:
                    if isinstance(frame, spec.Connection.CloseOk):
                        return
                    if isinstance(frame, spec.Connection.Close):
                        return await self.close(self.__exception_by_code(frame)
                                                )
                    elif isinstance(frame, Heartbeat):
                        continue

                    log.error("Unexpected frame %r", frame)
                    continue

                if self.channels.get(channel) is None:
                    # Not inside an ``except`` block, so log.exception would
                    # only attach a meaningless "NoneType: None" traceback
                    log.error("Got frame for closed channel %d: %r",
                              channel, frame)
                    continue

                ch = self.channels[channel]

                if isinstance(frame, CHANNEL_CLOSE_RESPONSES):
                    self.channels[channel] = None

                await ch.frames.put((weight, frame))
        except asyncio.CancelledError as e:
            log.debug("Reader task cancelled:", exc_info=e)
        except asyncio.IncompleteReadError as e:
            log.debug("Can not read bytes from server:", exc_info=e)
            await self.close(ConnectionError(*e.args))
        except Exception as e:
            log.debug("Reader task exited because:", exc_info=e)
            await self.close(e)

    @staticmethod
    async def __close_writer(writer: asyncio.StreamWriter):
        """Close ``writer`` (if any) and wait for it when supported."""
        if writer is None:
            return

        writer.close()

        # StreamWriter.wait_closed() only exists on Python 3.7+
        if hasattr(writer, "wait_closed"):
            await writer.wait_closed()

    async def _on_close(self, ex=exc.ConnectionClosed(0, "normal closed")):
        """Notify the server (best effort) and close the transport.

        Sends Connection.CloseOk when ``ex`` is a ConnectionClosed and
        Connection.Close otherwise; failures to send are ignored. Then the
        reader/writer are detached and the writer is closed.
        """
        frame = (spec.Connection.CloseOk() if isinstance(
            ex, exc.ConnectionClosed) else spec.Connection.Close())

        await asyncio.gather(self.__rpc(frame, wait_response=False),
                             return_exceptions=True)

        writer = self.writer
        self.reader = None
        self.writer = None
        self._reader_task = None

        await asyncio.gather(self.__close_writer(writer),
                             return_exceptions=True)

        # The reader task is deliberately not awaited here: it calls
        # ``self.close()`` itself (server Close, read errors), so waiting for
        # it from inside the close path could deadlock. Closing the writer
        # closes the transport, which should make its pending read fail and
        # the task finish; NOTE(review): its cancellation is presumably also
        # handled by Base's task bookkeeping — confirm.

    @property
    def server_capabilities(self) -> ArgumentsType:
        return self.server_properties["capabilities"]

    @property
    def basic_nack(self) -> bool:
        return self.server_capabilities.get("basic.nack")

    @property
    def consumer_cancel_notify(self) -> bool:
        return self.server_capabilities.get("consumer_cancel_notify")

    @property
    def exchange_exchange_bindings(self) -> bool:
        return self.server_capabilities.get("exchange_exchange_bindings")

    @property
    def publisher_confirms(self):
        return self.server_capabilities.get("publisher_confirms")

    async def channel(self,
                      channel_number: int = None,
                      publisher_confirms=True,
                      frame_buffer=FRAME_BUFFER,
                      **kwargs) -> Channel:
        """Create and open a channel on this connection.

        :param channel_number: explicit number in 1..65535; when None, the
            next free number after the highest one in ``channels`` is used,
            wrapping around to 1 after 65535.
        :param publisher_confirms: passed to ``Channel``; requires the
            server's ``publisher_confirms`` capability.
        :param frame_buffer: passed to ``Channel``.
        :raises RuntimeError: the connection is closed.
        :raises ValueError: confirms unsupported by the server, or the
            channel number is already used or out of range.
        """
        await self.connected.wait()

        if self.is_closed:
            raise RuntimeError("%r closed" % self)

        if not self.publisher_confirms and publisher_confirms:
            raise ValueError("Server doesn't support publisher_confirms")

        if channel_number is None:
            async with self.last_channel_lock:
                if self.channels:
                    self.last_channel = max(self.channels.keys())

                while self.last_channel in self.channels.keys():
                    self.last_channel += 1

                    if self.last_channel > 65535:
                        log.warning("Resetting channel number for %r", self)
                        self.last_channel = 1
                        # switching context for prevent blocking event-loop
                        await asyncio.sleep(0)

                channel_number = self.last_channel
        elif channel_number in self.channels:
            raise ValueError("Channel %d already used" % channel_number)

        # Channel 0 carries connection-level frames (handled by __reader),
        # so only 1..65535 can host a Channel
        if not 1 <= channel_number <= 65535:
            raise ValueError(
                "Channel number %d is out of range 1..65535" % channel_number)

        channel = Channel(self,
                          channel_number,
                          frame_buffer=frame_buffer,
                          publisher_confirms=publisher_confirms,
                          **kwargs)

        self.channels[channel_number] = channel

        try:
            await channel.open()
        except Exception:
            self.channels[channel_number] = None
            raise

        return channel

    async def __aenter__(self):
        """Connect and return the connection for ``async with`` use."""
        await self.connect()
        return self
コード例 #3
0
ファイル: connection.py プロジェクト: caiohsramos/aiormq
class Connection(Base):
    """AMQP client connection over an asyncio stream reader/writer pair.

    Virtual host, TLS files and heartbeat settings are read from the URL.
    :meth:`connect` performs the protocol handshake and then starts a
    background reader task (which dispatches incoming frames to channels)
    and a heartbeat task.
    """

    FRAME_BUFFER = 10
    # Interval between sending heartbeats based on the heartbeat(timeout)
    HEARTBEAT_INTERVAL_MULTIPLIER = 0.5
    # Close the connection when nothing has been received from the server
    # for three heartbeat(timeout) periods
    HEARTBEAT_GRACE_MULTIPLIER = 3
    # Heartbeat frame marshalled once for channel 0 and reused on every send
    _HEARTBEAT = pamqp.frame.marshal(Heartbeat(), 0)

    @staticmethod
    def _parse_ca_data(data):
        """Base64-decode the ``cadata`` URL parameter; None when empty."""
        return b64decode(data) if data else None

    def __init__(self,
                 url: URLorStr,
                 *,
                 parent=None,
                 loop: asyncio.AbstractEventLoop = None):
        """Parse connection settings from ``url``; no I/O happens here.

        :param url: AMQP URL; the path selects the virtual host and the
            query parameters configure TLS and heartbeats.
        :param parent: passed through to ``Base``.
        :param loop: event loop; defaults to ``asyncio.get_event_loop()``.
        """
        super().__init__(loop=loop or asyncio.get_event_loop(), parent=parent)

        self.url = URL(url)

        # An empty path or a bare "/" selects the default virtual host "/"
        if self.url.path == '/' or not self.url.path:
            self.vhost = '/'
        else:
            self.vhost = self.url.path[1:]

        self.reader = None  # type: asyncio.StreamReader
        self.writer = None  # type: asyncio.StreamWriter
        self.ssl_certs = SSLCerts(
            cafile=self.url.query.get('cafile'),
            capath=self.url.query.get('capath'),
            cadata=self._parse_ca_data(self.url.query.get('cadata')),
            key=self.url.query.get('keyfile'),
            cert=self.url.query.get('certfile'),
            verify=self.url.query.get('no_verify_ssl', '0') == '0')

        self.started = False
        # asyncio primitives are created without ``loop=``: that argument
        # was deprecated in Python 3.8 and removed in 3.10.
        # Serialises frame reads (exposed through the ``lock`` property)
        self.__lock = asyncio.Lock()
        # Serialises ``writer.drain()`` calls
        self.__drain_lock = asyncio.Lock()

        # Channel number -> Channel; None marks a closed or failed channel
        self.channels = {}  # type: typing.Dict[int, typing.Optional[Channel]]

        self.server_properties = None  # type: spec.Connection.OpenOk
        self.connection_tune = None  # type: spec.Connection.TuneOk

        self.last_channel = 0

        self.heartbeat_monitoring = parse_bool(
            self.url.query.get('heartbeat_monitoring', '1'))
        self.heartbeat_timeout = parse_int(self.url.query.get(
            'heartbeat', '0'))
        self.heartbeat_last_received = 0
        self.last_channel_lock = asyncio.Lock()
        self.connected = asyncio.Event()

    @property
    def lock(self):
        """Lock guarding frame reads; raises RuntimeError once closed."""
        if self.is_closed:
            raise RuntimeError('%r closed' % self)

        return self.__lock

    async def drain(self):
        """Wait for the writer's buffer to flush, one drain at a time.

        :raises RuntimeError: there is no writer (not connected / closed).
        """
        async with self.__drain_lock:
            if not self.writer:
                raise RuntimeError("Writer is %r" % self.writer)
            return await self.writer.drain()

    @property
    def is_opened(self):
        return self.writer is not None and not self.is_closed

    def __str__(self):
        return str(censor_url(self.url))

    def _get_ssl_context(self):
        """Build the client-side TLS context from the URL's TLS parameters.

        The context is always created for ``Purpose.SERVER_AUTH``: this side
        is the client and must verify the broker's certificate. A
        ``keyfile``/``certfile`` pair adds a client certificate, and any
        ``no_verify_ssl`` value other than '0' disables verification.
        """
        context = ssl.create_default_context(
            ssl.Purpose.SERVER_AUTH,
            capath=self.ssl_certs.capath,
            cafile=self.ssl_certs.cafile,
            cadata=self.ssl_certs.cadata,
        )

        if self.ssl_certs.key:
            context.load_cert_chain(
                self.ssl_certs.cert,
                self.ssl_certs.key,
            )

        if not self.ssl_certs.verify:
            # check_hostname must be switched off before verify_mode can be
            # lowered to CERT_NONE
            context.check_hostname = False
            context.verify_mode = ssl.CERT_NONE

        return context

    @staticmethod
    def _client_capabilities():
        """Return the ``client_properties`` table sent in Connection.StartOk."""
        return {
            'platform': PLATFORM,
            'version': __version__,
            'product': PRODUCT,
            'capabilities': {
                'authentication_failure_close': True,
                'basic.nack': True,
                'connection.blocked': False,
                'consumer_cancel_notify': True,
                'publisher_confirms': True
            },
            'information': 'See https://github.com/mosquito/aiormq/',
        }

    @staticmethod
    def _credentials_class(start_frame: spec.Connection.Start):
        """Return the first server-offered mechanism known to AuthMechanism.

        :raises exc.AuthenticationError: none of the offered mechanisms is
            supported.
        """
        for mechanism in start_frame.mechanisms.decode().split():
            with suppress(KeyError):
                return AuthMechanism[mechanism]

        raise exc.AuthenticationError(start_frame.mechanisms,
                                      [m.name for m in AuthMechanism])

    async def __rpc(self, request: spec.Frame, wait_response=True):
        """Send a channel-0 method frame and optionally wait for the reply.

        :returns: the reply frame, or None when ``wait_response`` is False.
        :raises exc.ProbableAuthenticationError: the server closed the
            connection with reply code 403.
        :raises exc.ConnectionClosed: the server closed the connection with
            another reply code.
        :raises spec.AMQPInternalError: a synchronous request received a
            reply that is not among its valid responses.
        """
        self.writer.write(pamqp.frame.marshal(request, 0))

        if not wait_response:
            return

        _, _, frame = await self.__receive_frame()

        # Handle Connection.Close first so a server-initiated close (e.g. a
        # 403 after bad credentials) is reported as such even when Close is
        # not among the request's valid_responses.
        if isinstance(frame, spec.Connection.Close):
            if frame.reply_code == 403:
                raise exc.ProbableAuthenticationError(frame.reply_text)

            raise exc.ConnectionClosed(frame.reply_code, frame.reply_text)

        if request.synchronous and frame.name not in request.valid_responses:
            raise spec.AMQPInternalError(frame, frame)

        return frame

    @task
    async def connect(self):
        """Open the socket and perform the connection handshake.

        Sends the protocol header, answers Connection.Start with StartOk,
        answers Tune with TuneOk (using the URL's ``heartbeat`` value when it
        is positive) and sends Connection.Open. Then starts the reader and
        heartbeat tasks and schedules ``connected`` to be set.

        :raises RuntimeError: already connected.
        :raises ConnectionError: the TCP/TLS connection could not be opened.
        :raises exc.IncompatibleProtocolError: the stream ended while waiting
            for the first server frame.
        """
        if self.writer is not None:
            raise RuntimeError("Already connected")

        ssl_context = None

        if self.url.scheme == 'amqps':
            # Loading certificates touches the filesystem; keep it off the
            # event loop
            ssl_context = await self.loop.run_in_executor(
                None, self._get_ssl_context)

        try:
            self.reader, self.writer = await asyncio.open_connection(
                self.url.host, self.url.port, ssl=ssl_context)
        except OSError as e:
            raise ConnectionError(*e.args) from e

        try:
            protocol_header = ProtocolHeader()
            self.writer.write(protocol_header.marshal())

            res = await self.__receive_frame()
            _, _, frame = res  # type: spec.Connection.Start
            self.heartbeat_last_received = self.loop.time()
        except EOFError as e:
            raise exc.IncompatibleProtocolError(*e.args) from e

        credentials = self._credentials_class(frame)

        self.server_properties = frame.server_properties

        # noinspection PyTypeChecker
        self.connection_tune = await self.__rpc(
            spec.Connection.StartOk(
                client_properties=self._client_capabilities(),
                mechanism=credentials.name,
                response=credentials.value(self).marshal())
        )  # type: spec.Connection.Tune

        if self.heartbeat_timeout > 0:
            self.connection_tune.heartbeat = self.heartbeat_timeout

        await self.__rpc(
            spec.Connection.TuneOk(
                channel_max=self.connection_tune.channel_max,
                frame_max=self.connection_tune.frame_max,
                heartbeat=self.connection_tune.heartbeat,
            ),
            wait_response=False,
        )

        await self.__rpc(spec.Connection.Open(virtual_host=self.vhost))

        # noinspection PyAsyncCall
        self.create_task(self.__reader())

        # noinspection PyAsyncCall
        self.create_task(self.__heartbeat_task())
        self.loop.call_soon(self.connected.set)

        return True

    async def __heartbeat_task(self):
        """Send heartbeats and, optionally, watch for a silent server.

        Does nothing when the negotiated heartbeat is 0. Otherwise a
        heartbeat is written every ``heartbeat * HEARTBEAT_INTERVAL_MULTIPLIER``
        seconds until the writer goes away; with monitoring enabled the
        connection is closed when nothing was received for
        ``heartbeat * HEARTBEAT_GRACE_MULTIPLIER`` seconds.
        """
        if not self.connection_tune.heartbeat:
            return

        heartbeat_interval = (self.connection_tune.heartbeat *
                              self.HEARTBEAT_INTERVAL_MULTIPLIER)
        heartbeat_grace_timeout = (self.connection_tune.heartbeat *
                                   self.HEARTBEAT_GRACE_MULTIPLIER)

        while True:
            await asyncio.sleep(heartbeat_interval)

            writer = self.writer
            # The connection was closed while sleeping (``_on_close`` sets
            # the writer to None); writing would raise AttributeError.
            if writer is None:
                return

            # Send heartbeat to server unconditionally
            writer.write(self._HEARTBEAT)

            if not self.heartbeat_monitoring:
                continue

            # Check if the server sent us something
            # within the heartbeat grace period
            last_heartbeat = self.loop.time() - self.heartbeat_last_received

            if last_heartbeat <= heartbeat_grace_timeout:
                continue

            await self.close(
                ConnectionError(
                    'Server connection probably hang, last heartbeat '
                    'received %.3f seconds ago' % last_heartbeat))

            return

    @task
    async def __receive_frame(self) -> typing.Tuple[int, int, spec.Frame]:
        """Read exactly one frame from the stream and unmarshal it.

        :returns: the ``(weight, channel, frame)`` tuple from pamqp's
            unmarshal (as unpacked by callers).
        :raises spec.AMQPFrameError: the frame-type byte is 0x00.
        :raises spec.AMQPSyntaxError: the first frame starts with ``AMQP``
            (a protocol header instead of a frame).
        """
        async with self.lock:
            frame_header = await self.reader.readexactly(1)

            # A NUL frame-type byte is treated as a framing error; the rest
            # of the stream (read until EOF) is attached for diagnostics.
            if frame_header == b'\x00':
                raise spec.AMQPFrameError(await self.reader.read())

            frame_header += await self.reader.readexactly(6)

            if not self.started and frame_header.startswith(b'AMQP'):
                raise spec.AMQPSyntaxError
            else:
                self.started = True

            frame_type, _, frame_length = pamqp.frame.frame_parts(frame_header)

            # Payload plus the trailing frame-end byte
            frame_payload = await self.reader.readexactly(frame_length + 1)

            return pamqp.frame.unmarshal(frame_header + frame_payload)

    @staticmethod
    def __exception_by_code(frame: spec.Connection.Close):
        """Map a Connection.Close reply code to an exception instance."""
        if frame.reply_code == 501:
            return exc.ConnectionFrameError(frame.reply_text)
        elif frame.reply_code == 502:
            return exc.ConnectionSyntaxError(frame.reply_text)
        elif frame.reply_code == 503:
            return exc.ConnectionCommandInvalid(frame.reply_text)
        elif frame.reply_code == 504:
            return exc.ConnectionChannelError(frame.reply_text)
        elif frame.reply_code == 505:
            return exc.ConnectionUnexpectedFrame(frame.reply_text)
        elif frame.reply_code == 506:
            return exc.ConnectionResourceError(frame.reply_text)
        elif frame.reply_code == 530:
            return exc.ConnectionNotAllowed(frame.reply_text)
        elif frame.reply_code == 540:
            return exc.ConnectionNotImplemented(frame.reply_text)
        elif frame.reply_code == 541:
            return exc.ConnectionInternalError(frame.reply_text)
        else:
            return exc.ConnectionClosed(frame.reply_code, frame.reply_text)

    async def __reader(self):
        """Read frames until EOF and route them.

        Channel-0 frames are handled here: Close closes the connection with
        the mapped exception and heartbeats are skipped. Frames for other
        channels are put on that channel's ``frames`` queue, and a channel is
        marked closed (None) on Channel.Close / Channel.CloseOk. Read errors
        close the connection.
        """
        try:
            while not self.reader.at_eof():
                weight, channel, frame = await self.__receive_frame()

                # Any received frame counts as a sign of life for the
                # heartbeat monitor
                self.heartbeat_last_received = self.loop.time()

                if channel == 0:
                    if isinstance(frame, spec.Connection.Close):
                        return await self.close(self.__exception_by_code(frame)
                                                )
                    elif isinstance(frame, Heartbeat):
                        continue

                    log.error('Unexpected frame %r', frame)
                    continue

                if self.channels.get(channel) is None:
                    # Not inside an ``except`` block, so log.exception would
                    # only attach a meaningless "NoneType: None" traceback
                    log.error("Got frame for closed channel %d: %r",
                              channel, frame)
                    continue

                ch = self.channels[channel]

                channel_close_responses = (spec.Channel.Close,
                                           spec.Channel.CloseOk)

                if isinstance(frame, channel_close_responses):
                    self.channels[channel] = None

                await ch.frames.put((weight, frame))
        except asyncio.CancelledError as e:
            log.debug("Reader task cancelled:", exc_info=e)
        except asyncio.IncompleteReadError as e:
            log.debug("Can not read bytes from server:", exc_info=e)
            await self.close(ConnectionError(*e.args))
        except Exception as e:
            log.debug("Reader task exited because:", exc_info=e)
            await self.close(e)

    # noinspection PyShadowingNames
    async def _on_close(self, exc=exc.ConnectionClosed(0, 'normal closed')):
        """Detach the reader/writer and close the underlying transport.

        A no-op for the transport when there is no writer (never connected
        or already closed).
        """
        writer = self.writer
        self.reader = None
        self.writer = None

        if writer is None:
            return

        writer.close()

        # StreamWriter.wait_closed() only exists on Python 3.7+
        if hasattr(writer, 'wait_closed'):
            await writer.wait_closed()

    @property
    def server_capabilities(self) -> ArgumentsType:
        return self.server_properties['capabilities']

    @property
    def basic_nack(self) -> bool:
        return self.server_capabilities.get('basic.nack')

    @property
    def consumer_cancel_notify(self) -> bool:
        return self.server_capabilities.get('consumer_cancel_notify')

    @property
    def exchange_exchange_bindings(self) -> bool:
        return self.server_capabilities.get('exchange_exchange_bindings')

    @property
    def publisher_confirms(self):
        return self.server_capabilities.get('publisher_confirms')

    async def channel(self,
                      channel_number: int = None,
                      publisher_confirms=True,
                      frame_buffer=FRAME_BUFFER,
                      **kwargs) -> Channel:
        """Create and open a channel on this connection.

        :param channel_number: explicit number in 1..65535; when None, the
            next number after ``last_channel`` that is not in ``channels``
            is used, wrapping around to 1 after 65535.
        :param publisher_confirms: passed to ``Channel``; requires the
            server's ``publisher_confirms`` capability.
        :param frame_buffer: passed to ``Channel``.
        :raises RuntimeError: the connection is closed.
        :raises ValueError: confirms unsupported by the server, or the
            channel number is already used or out of range.
        """
        await self.connected.wait()

        if self.is_closed:
            raise RuntimeError('%r closed' % self)

        if not self.publisher_confirms and publisher_confirms:
            raise ValueError("Server doesn't support publisher_confirms")

        if channel_number is None:
            async with self.last_channel_lock:
                self.last_channel += 1

                while True:
                    # Wrap before the membership test, so that advancing
                    # past 65535 restarts at 1 instead of yielding 65536
                    if self.last_channel > 65535:
                        log.warning("Resetting channel number for %r", self)
                        self.last_channel = 1
                        # switching context for prevent blocking event-loop
                        await asyncio.sleep(0)

                    if self.last_channel not in self.channels:
                        break

                    self.last_channel += 1

                channel_number = self.last_channel
        elif channel_number in self.channels:
            raise ValueError("Channel %d already used" % channel_number)

        # Channel 0 carries connection-level frames (handled by __reader),
        # so only 1..65535 can host a Channel
        if not 1 <= channel_number <= 65535:
            raise ValueError(
                'Channel number %d is out of range 1..65535' % channel_number)

        channel = Channel(self,
                          channel_number,
                          frame_buffer=frame_buffer,
                          publisher_confirms=publisher_confirms,
                          **kwargs)

        self.channels[channel_number] = channel

        try:
            await channel.open()
        except Exception:
            self.channels[channel_number] = None
            raise

        return channel

    async def __aenter__(self):
        """Connect and return the connection for ``async with`` use."""
        await self.connect()
        return self
コード例 #4
0
ファイル: channel0.py プロジェクト: priya-gitTest/wlnupdates
    def send_heartbeat(self):
        """Send Heartbeat frame.

        Does nothing when the owning connection is no longer open, so a
        heartbeat is never written to a closed connection.

        :return:
        """
        if not self._connection.is_open:
            return
        self._write_frame(Heartbeat())
コード例 #5
0
 def test_channel0_heartbeat(self):
     """Channel0.on_frame returns None for a Heartbeat frame."""
     channel0 = Channel0(self.connection)
     result = channel0.on_frame(Heartbeat())
     self.assertIsNone(result)
コード例 #6
0
class Connection(Base, AbstractConnection):
    FRAME_BUFFER_SIZE = 10
    # Interval between sending heartbeats based on the heartbeat(timeout)
    HEARTBEAT_INTERVAL_MULTIPLIER = 0.5
    # Allow three missed heartbeats (based on heartbeat(timeout)
    HEARTBEAT_GRACE_MULTIPLIER = 3
    _HEARTBEAT = ChannelFrame(
        frames=(Heartbeat(),),
        channel_number=0,
    )

    READER_CLOSE_TIMEOUT = 2

    _reader_task: TaskType
    _writer_task: TaskType
    write_queue: asyncio.Queue
    server_properties: ArgumentsType
    connection_tune: spec.Connection.Tune
    channels: Dict[int, Optional[AbstractChannel]]

    @staticmethod
    def _parse_ca_data(data: Optional[str]) -> Optional[bytes]:
        """Decode the base64 ``cadata`` URL parameter.

        :param data: base64 text, or ``None``/empty when not supplied.
        :return: the decoded bytes, or ``None`` for missing/empty input.
        """
        if not data:
            return None
        return b64decode(data)

    def __init__(
        self,
        url: URLorStr,
        *,
        loop: asyncio.AbstractEventLoop = None,
        context: ssl.SSLContext = None
    ):
        """Initialise connection state from *url* (the socket is opened later).

        :param url: broker URL. Its path selects the virtual host. Its query
            string may carry ``cafile``, ``capath``, ``cadata`` (base64),
            ``keyfile``, ``certfile``, ``no_verify_ssl``, ``timeout``,
            ``heartbeat`` and ``name``.
        :param loop: event loop handed to the base class. Falls back to
            ``asyncio.get_event_loop()``.
        :param context: ready-made SSL context for ``connect()``. When it is
            ``None`` and the scheme is ``amqps``, ``connect()`` builds one
            from the URL query.
        """
        super().__init__(loop=loop or asyncio.get_event_loop(), parent=None)

        self.url = URL(url)
        # An absolute URL without an explicit port gets its scheme's default.
        if self.url.is_absolute() and not self.url.port:
            self.url = self.url.with_port(DEFAULT_PORTS[self.url.scheme])

        path = self.url.path
        self.vhost = path[1:] if path and path != "/" else "/"

        query = self.url.query
        self.ssl_context = context
        self.ssl_certs = SSLCerts(
            cafile=query.get("cafile"),
            capath=query.get("capath"),
            cadata=self._parse_ca_data(query.get("cadata")),
            key=query.get("keyfile"),
            cert=query.get("certfile"),
            verify=query.get("no_verify_ssl", "0") == "0",
        )

        self.started = False
        self.channels = {}
        # Bounded outgoing queue consumed by the writer task.
        self.write_queue = asyncio.Queue(maxsize=self.FRAME_BUFFER_SIZE)

        self.last_channel = 1

        self.timeout = parse_int(query.get("timeout", "60"))
        self.heartbeat_timeout = parse_heartbeat(query.get("heartbeat", "60"))
        self.last_channel_lock = asyncio.Lock()
        self.connected = asyncio.Event()
        self.connection_name = query.get("name")

        # Reason sent in Connection.Close when the writer shuts down.
        self.__close_reply_code: int = REPLY_SUCCESS
        self.__close_reply_text: str = "normally closed"
        self.__close_class_id: int = 0
        self.__close_method_id: int = 0

    async def ready(self) -> None:
        """Wait until the reader task has flagged the connection as connected."""
        connected_event = self.connected
        await connected_event.wait()

    def set_close_reason(
        self, reply_code: int = REPLY_SUCCESS,
        reply_text: str = "normally closed",
        class_id: int = 0, method_id: int = 0,
    ) -> None:
        """Record the fields the writer puts into its final Connection.Close.

        :param reply_code: AMQP reply code.
        :param reply_text: human-readable reason.
        :param class_id: class id of the method that caused the close, or 0.
        :param method_id: method id of the method that caused the close, or 0.
        """
        (
            self.__close_reply_code,
            self.__close_reply_text,
            self.__close_class_id,
            self.__close_method_id,
        ) = (reply_code, reply_text, class_id, method_id)

    @property
    def is_opened(self) -> bool:
        """Return True while the writer task is running and we are not closed.

        ``False`` before ``connect()`` has started the writer task.
        """
        # ``_writer_task`` is only assigned by connect(); the class-level
        # annotation does not create the attribute.
        writer_task = getattr(self, "_writer_task", None)
        if writer_task is None:
            return False
        # Explicit ``and``: ``not x.done() is not None`` would parse as
        # ``not (x.done() is not None)``, which is always False.
        return not writer_task.done() and not self.is_closed

    def __str__(self) -> str:
        """Render the URL via ``censor_url`` (presumably hiding the password)."""
        censored = censor_url(self.url)
        return str(censored)

    def _get_ssl_context(self) -> ssl.SSLContext:
        """Build a client TLS context from the certificate options in the URL.

        Loads the CA file/path/data. Adds a client certificate chain when
        ``certfile`` is set. Disables hostname and certificate checks when
        ``no_verify_ssl`` is not "0".
        """
        certs = self.ssl_certs
        ctx = ssl.create_default_context(
            ssl.Purpose.SERVER_AUTH,
            cafile=certs.cafile,
            capath=certs.capath,
            cadata=certs.cadata,
        )

        if certs.cert:
            ctx.load_cert_chain(certs.cert, certs.key)

        if not certs.verify:
            # check_hostname must be turned off before verify_mode may be
            # set to CERT_NONE.
            ctx.check_hostname = False
            ctx.verify_mode = ssl.CERT_NONE

        return ctx

    def _client_properties(self, **kwargs: Any) -> Dict[str, Any]:
        """Build the client-properties table sent with Connection.StartOk.

        Defaults are overlaid by the entries from
        ``parse_connection_name(self.connection_name)``, then by *kwargs*.
        """
        capabilities = {
            "authentication_failure_close": True,
            "basic.nack": True,
            "connection.blocked": False,
            "consumer_cancel_notify": True,
            "publisher_confirms": True,
        }
        result: Dict[str, Any] = dict(
            platform=PLATFORM,
            version=__version__,
            product=PRODUCT,
            capabilities=capabilities,
            information="See https://github.com/mosquito/aiormq/",
        )

        result.update(parse_connection_name(self.connection_name))
        result.update(kwargs)
        return result

    def _credentials_class(
        self,
        start_frame: spec.Connection.Start,
    ) -> AuthMechanism:
        """Pick the auth mechanism named by the URL's ``auth`` parameter.

        :param start_frame: the broker's Connection.Start, whose
            ``mechanisms`` lists what the server accepts.
        :return: the matching ``AuthMechanism`` member.
        :raises AuthenticationError: if the server does not offer the
            requested mechanism or this client does not implement it.
        """
        requested = self.url.query.get("auth", "plain").upper()
        offered = start_frame.mechanisms.split()

        if requested in offered:
            try:
                return AuthMechanism[requested]
            except KeyError:
                # Offered by the server but not implemented here.
                pass

        supported = [mechanism.name for mechanism in AuthMechanism]
        raise AuthenticationError(start_frame.mechanisms, supported)

    @staticmethod
    async def _rpc(
        request: Frame, writer: asyncio.StreamWriter,
        frame_receiver: FrameReceiver,
        wait_response: bool = True
    ) -> Optional[FrameTypes]:
        """Send *request* on channel 0 and optionally read the reply.

        In this class it is called from ``connect()`` during the handshake.

        :param request: method frame to send.
        :param writer: stream the frame is written to.
        :param frame_receiver: source of the reply frame.
        :param wait_response: when False, return right after writing.
        :return: the reply frame, or ``None`` if not waiting.
        :raises AMQPInternalError: synchronous request answered by a frame
            outside ``request.valid_responses``.
        :raises ProbableAuthenticationError: reply is Connection.Close 403.
        :raises ConnectionClosed: reply is any other Connection.Close.
        """
        writer.write(pamqp.frame.marshal(request, 0))
        if not wait_response:
            return None

        _weight, _channel, response = await frame_receiver.get_frame()

        unexpected = (
            request.synchronous and
            response.name not in request.valid_responses
        )
        if unexpected:
            raise AMQPInternalError(
                "one of {!r}".format(request.valid_responses), response,
            )

        if not isinstance(response, spec.Connection.Close):
            return response

        if response.reply_code == 403:
            raise ProbableAuthenticationError(response.reply_text)
        raise ConnectionClosed(response.reply_code, response.reply_text)

    @task
    async def connect(self, client_properties: dict = None) -> bool:
        """Open the socket, run the AMQP handshake and start I/O tasks.

        Sequence:
        - Send the protocol header and expect Connection.Start.
        - Reply with StartOk, then receive Tune and reply with TuneOk.
        - Send Connection.Open and expect OpenOk.
        - Start the reader and writer tasks.

        If the handshake fails before the writer task exists, the socket is
        closed here, because nothing else owns it yet.

        :param client_properties: extra entries merged into the properties
            sent with Connection.StartOk.
        :return: ``True`` once the connection is open.
        :raises RuntimeError: if this instance was already connected.
        :raises ConnectionError: if the TCP/TLS connection cannot be opened.
        :raises IncompatibleProtocolError: if the peer closes the stream
            instead of answering the protocol header.
        :raises AMQPInternalError: on an unexpected handshake frame.
        """
        if hasattr(self, "_writer_task"):
            raise RuntimeError("Connection already connected")

        ssl_context = self.ssl_context

        if ssl_context is None and self.url.scheme == "amqps":
            # Loading certificates reads files; keep it off the event loop.
            ssl_context = await self.loop.run_in_executor(
                None, self._get_ssl_context,
            )

        log.debug("Connecting to: %s", self)
        try:
            reader, writer = await asyncio.open_connection(
                self.url.host, self.url.port, ssl=ssl_context,
            )

            frame_receiver = FrameReceiver(
                reader,
                (self.timeout + 1) * self.HEARTBEAT_GRACE_MULTIPLIER,
            )
        except OSError as e:
            raise ConnectionError(*e.args) from e

        frame: Optional[FrameTypes]

        try:
            try:
                protocol_header = ProtocolHeader()
                writer.write(protocol_header.marshal())

                _, _, frame = await frame_receiver.get_frame()
            except EOFError as e:
                raise IncompatibleProtocolError(*e.args) from e

            if not isinstance(frame, spec.Connection.Start):
                raise AMQPInternalError("Connection.Start", frame)

            credentials = self._credentials_class(frame)
        except BaseException:
            # Nothing owns the socket yet: release it before propagating
            # (includes cancellation, e.g. a wait_for() timeout on connect).
            writer.close()
            raise

        server_properties: ArgumentsType = frame.server_properties

        try:
            frame = await self._rpc(
                spec.Connection.StartOk(
                    client_properties=self._client_properties(
                        **(client_properties or {}),
                    ),
                    mechanism=credentials.name,
                    response=credentials.value(self).marshal(),
                ),
                writer=writer,
                frame_receiver=frame_receiver,
            )

            if not isinstance(frame, spec.Connection.Tune):
                raise AMQPInternalError("Connection.Tune", frame)

            connection_tune: spec.Connection.Tune = frame
            # The client's configured heartbeat replaces the server proposal.
            connection_tune.heartbeat = self.heartbeat_timeout

            await self._rpc(
                spec.Connection.TuneOk(
                    channel_max=connection_tune.channel_max,
                    frame_max=connection_tune.frame_max,
                    heartbeat=connection_tune.heartbeat,
                ),
                writer=writer,
                frame_receiver=frame_receiver,
                wait_response=False,
            )

            frame = await self._rpc(
                spec.Connection.Open(virtual_host=self.vhost),
                writer=writer,
                frame_receiver=frame_receiver,
            )

            if not isinstance(frame, spec.Connection.OpenOk):
                raise AMQPInternalError("Connection.OpenOk", frame)

            # noinspection PyAsyncCall
            self._reader_task = self.create_task(self.__reader(frame_receiver))
            self._reader_task.add_done_callback(self._on_reader_done)

            # noinspection PyAsyncCall
            self._writer_task = self.create_task(self.__writer(writer))
        except BaseException as e:
            if not hasattr(self, "_writer_task"):
                # The writer task never took ownership of the socket, so
                # close it here before close() runs (and possibly fails).
                writer.close()
            if isinstance(e, Exception):
                await self.close(e)
            raise

        self.connection_tune = connection_tune
        self.server_properties = server_properties
        return True

    def _on_reader_done(self, task: asyncio.Task) -> None:
        """Done-callback for the reader task.

        Cancels the writer task unless it has already finished. If the reader
        ended with an exception (not a cancellation), the close reason becomes
        500 / "reader unexpected closed" and ``close()`` is scheduled with
        that exception.
        """
        log.debug("Reader exited for %r", self)

        writer_task = self._writer_task
        if not writer_task.done():
            writer_task.cancel()

        if task.cancelled():
            return

        reader_error = task.exception()
        if reader_error is None:
            return

        log.debug("Cancelling cause reader exited abnormally")
        self.set_close_reason(
            reply_code=500, reply_text="reader unexpected closed",
        )
        self.create_task(self.close(reader_error))

    async def __reader(self, frame_receiver: FrameReceiver) -> None:
        """Read frames from the broker and route them until the link ends.

        Sets ``connected`` first. Channel 0 frames are handled here:
        - Connection.CloseOk ends the reader.
        - Connection.Close is answered with CloseOk and raised as an error.
        - Heartbeats are ignored.

        Frames for any other channel are put on that channel's ``frames``
        queue.

        :raises: the exception built by ``exception_by_code`` when the broker
            sends Connection.Close.
        """
        self.connected.set()

        async for weight, channel, frame in frame_receiver:
            log.debug(
                "Received frame %r in channel #%d weight=%s on %r",
                frame, channel, weight, self,
            )

            if channel == 0:
                if isinstance(frame, spec.Connection.CloseOk):
                    return

                if isinstance(frame, spec.Connection.Close):
                    # log.error rather than log.exception: no exception is
                    # being handled here, so log.exception would only append
                    # a meaningless "NoneType: None" traceback.
                    log.error(
                        "Unexpected connection close from remote \"%s\", "
                        "Connection.Close(reply_code=%r, reply_text=%r)",
                        self, frame.reply_code, frame.reply_text,
                    )

                    # Acknowledge the broker's close before surfacing it.
                    self.write_queue.put_nowait(
                        ChannelFrame(
                            channel_number=0,
                            frames=[spec.Connection.CloseOk()],
                        ),
                    )
                    raise exception_by_code(frame)
                elif isinstance(frame, Heartbeat):
                    continue
                elif isinstance(frame, spec.Channel.CloseOk):
                    # NOTE(review): ``channel`` is always 0 in this branch, so
                    # this only drops a channel-0 entry, if any; the frame is
                    # still logged as unexpected below.
                    self.channels.pop(channel, None)

                log.error("Unexpected frame %r", frame)
                continue

            ch: Optional[AbstractChannel] = self.channels.get(channel)
            if ch is None:
                log.error(
                    "Got frame for closed channel %d: %r", channel, frame,
                )
                continue

            if isinstance(frame, CHANNEL_CLOSE_RESPONSES):
                # Keep the number reserved but stop routing frames to it.
                self.channels[channel] = None

            await ch.frames.put((weight, frame))

    async def __frame_iterator(self) -> AsyncIterableType[ChannelFrame]:
        """Yield queued outgoing frames until the connection is closed.

        Whenever ``self.timeout`` seconds pass with nothing queued, the
        prebuilt heartbeat frame is yielded instead.
        """
        while not self.is_closed:
            try:
                queued = await asyncio.wait_for(
                    self.write_queue.get(), timeout=self.timeout,
                )
            except asyncio.TimeoutError:
                yield self._HEARTBEAT
                continue

            yield queued
            self.write_queue.task_done()

    async def __writer(self, writer: asyncio.StreamWriter) -> None:
        """Serialise outgoing ChannelFrames onto *writer* until stopped.

        Frames come from ``__frame_iterator``: queued frames, or a heartbeat
        after ``self.timeout`` idle seconds. Writing stops once a
        Connection.CloseOk has been sent.

        On cancellation, if the writer still looks usable, the task does the
        following and then re-raises the cancellation:
        - sends a Connection.Close carrying the reason recorded by
          ``set_close_reason``;
        - drains and closes the writer.
        """
        channel_frame: ChannelFrame

        try:
            async for channel_frame in self.__frame_iterator():
                log.debug("Prepare to send %r", channel_frame)

                frame: FrameTypes

                for frame in channel_frame.frames:
                    log.debug(
                        "Sending frame %r in channel #%d on %r",
                        frame, channel_frame.channel_number, self,
                    )

                    try:
                        writer.write(
                            pamqp.frame.marshal(
                                frame, channel_frame.channel_number,
                            ),
                        )
                    except BaseException as e:
                        log.exception(
                            "Failed to write frame to channel %d: %r",
                            channel_frame.channel_number,
                            frame,
                        )
                        # Route write failures into the
                        # ``except asyncio.CancelledError`` branch below.
                        raise asyncio.CancelledError from e

                    # CloseOk is queued by __reader in reply to a broker
                    # Connection.Close; nothing may be sent after it.
                    # NOTE(review): the socket is not closed on this path.
                    if isinstance(frame, spec.Connection.CloseOk):
                        return

                    # Resolve the producer's drain future once the transport
                    # buffer has drained.
                    # NOTE(review): for a multi-frame ChannelFrame this fires
                    # after the first frame; later frames are not drained
                    # before the future completes.
                    if (
                        channel_frame.drain_future is not None and
                        not channel_frame.drain_future.done()
                    ):
                        channel_frame.drain_future.set_result(
                            await writer.drain(),
                        )
        except asyncio.CancelledError:
            # Writer already closing/closed: nothing can be sent.
            if not self.__check_writer(writer):
                raise

            # Tell the broker why we are closing (see set_close_reason).
            frame = spec.Connection.Close(
                reply_code=self.__close_reply_code,
                reply_text=self.__close_reply_text,
                class_id=self.__close_class_id,
                method_id=self.__close_method_id,
            )

            writer.write(pamqp.frame.marshal(frame, 0))
            log.debug("Sending %r to %r", frame, self)

            # NOTE(review): if drain() raises (e.g. connection reset), the
            # writer is not closed here.
            await writer.drain()
            await self.__close_writer(writer)
            raise
        finally:
            log.debug("Writer exited for %r", self)

    @staticmethod
    async def __close_writer(writer: asyncio.StreamWriter) -> None:
        """Close *writer* and wait for it to finish closing when possible.

        A ``None`` writer is ignored. ``wait_closed()`` is awaited only when
        the writer has it (it was added to StreamWriter in Python 3.7).
        """
        if writer is not None:
            writer.close()
            if hasattr(writer, "wait_closed"):
                await writer.wait_closed()

    @staticmethod
    def __check_writer(writer: asyncio.StreamWriter) -> bool:
        """Return True if *writer* still looks usable for sending.

        Checks the first available of:
        - ``writer.is_closing()``;
        - the transport's ``is_closing()``;
        - ``can_write_eof()``.
        """
        if writer is None:
            return False

        if hasattr(writer, "is_closing"):
            closing = writer.is_closing()
        elif writer.transport:
            closing = writer.transport.is_closing()
        else:
            return writer.can_write_eof()

        return not closing

    async def _on_close(
        self,
        ex: Optional[ExceptionType] = ConnectionClosed(0, "normal closed")
    ) -> None:
        """Stop the reader task as part of closing the connection.

        Cancelling the reader fires ``_on_reader_done`` (registered in
        ``connect``), which cancels the writer. If the socket is still
        writable, the writer then sends Connection.Close.

        Safe when ``connect()`` failed before the reader was started, and on
        repeated calls: with no reader task there is nothing to cancel.

        :param ex: the reason for closing; only logged here.
        """
        log.debug("Closing connection %r cause: %r", self, ex)

        try:
            reader_task = self._reader_task
        except AttributeError:
            # Handshake failed before the reader started, or the task was
            # already released by an earlier close.
            return

        del self._reader_task

        if not reader_task.done():
            reader_task.cancel()

    @property
    def server_capabilities(self) -> ArgumentsType:
        """The ``capabilities`` table from the broker's server properties."""
        properties = self.server_properties
        return properties["capabilities"]   # type: ignore

    @property
    def basic_nack(self) -> bool:
        """Whether the broker advertises the ``basic.nack`` capability."""
        advertised = self.server_capabilities.get("basic.nack")
        return True if advertised else False

    @property
    def consumer_cancel_notify(self) -> bool:
        """Whether the broker advertises ``consumer_cancel_notify``."""
        advertised = self.server_capabilities.get("consumer_cancel_notify")
        return True if advertised else False

    @property
    def exchange_exchange_bindings(self) -> bool:
        """Whether the broker advertises ``exchange_exchange_bindings``."""
        capabilities = self.server_capabilities
        advertised = capabilities.get("exchange_exchange_bindings")
        return True if advertised else False

    @property
    def publisher_confirms(self) -> Optional[bool]:
        """The broker's ``publisher_confirms`` capability.

        :return: ``None`` when the capability is absent, otherwise its
            truthiness as a bool.
        """
        flag = self.server_capabilities.get("publisher_confirms")
        return None if flag is None else bool(flag)

    async def channel(
        self,
        channel_number: int = None,
        publisher_confirms: bool = True,
        frame_buffer_size: int = FRAME_BUFFER_SIZE,
        timeout: TimeoutType = None,
        **kwargs: Any
    ) -> AbstractChannel:
        """Create and open a channel on this connection.

        :param channel_number: explicit channel id in 1..65535. When omitted,
            the next id after the highest one in ``self.channels`` is used,
            wrapping around to 1 after 65535.
        :param publisher_confirms: request publisher confirms on the channel.
        :param frame_buffer_size: passed to ``Channel`` as ``frame_buffer``.
        :param timeout: passed to ``Channel.open``.
        :return: the opened channel.
        :raises RuntimeError: if the connection is closed.
        :raises ValueError: in any of these cases:
            - publisher confirms are requested but not supported;
            - the channel number is already used;
            - the channel number is outside 1..65535 (channel 0 is reserved
              for connection-level frames).
        """
        await self.connected.wait()

        if self.is_closed:
            raise RuntimeError("%r closed" % self)

        if not self.publisher_confirms and publisher_confirms:
            raise ValueError("Server doesn't support publisher_confirms")

        if channel_number is None:
            async with self.last_channel_lock:
                # Start from the highest id handed out so far. Closed channels
                # keep their key (with a None value), so ids are not reused
                # until the counter wraps.
                if self.channels:
                    self.last_channel = max(self.channels.keys())

                while self.last_channel in self.channels.keys():
                    self.last_channel += 1

                    if self.last_channel > 65535:
                        log.warning("Resetting channel number for %r", self)
                        self.last_channel = 1
                        # Yield to the event loop so a long search after a
                        # wrap-around does not block it.
                        await asyncio.sleep(0)

                channel_number = self.last_channel
        elif channel_number in self.channels:
            raise ValueError("Channel %d already used" % channel_number)

        if channel_number < 1:
            # Channel 0 carries connection-level frames (handled by the
            # reader itself), so a Channel there would never get its replies.
            raise ValueError(
                "Channel number must be between 1 and 65535, got %d"
                % channel_number,
            )
        if channel_number > 65535:
            raise ValueError("Channel number too large")

        channel = Channel(
            self,
            channel_number,
            frame_buffer=frame_buffer_size,
            publisher_confirms=publisher_confirms,
            **kwargs,
        )

        self.channels[channel_number] = channel

        try:
            await channel.open(timeout=timeout)
        except Exception:
            # Keep the number reserved but mark the channel as unusable.
            self.channels[channel_number] = None
            raise

        return channel

    async def __aenter__(self) -> AbstractConnection:
        """Connect when entering ``async with``; the block receives ``self``."""
        await self.connect()
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        await self.close(exc_val)