def send_heartbeat(self):
    """Write a Heartbeat frame, skipping it when the connection is not open.

    :return: None
    """
    connection_is_open = self._connection.is_open
    if connection_is_open:
        self._write_frame(Heartbeat())
class Connection(Base):
    """AMQP connection driven by an asyncio ``StreamReader``/``StreamWriter``.

    Settings come from the URL: the path selects the virtual host, and query
    parameters configure TLS (``cafile``, ``capath``, ``cadata``,
    ``keyfile``, ``certfile``, ``no_verify_ssl``), heartbeats (``heartbeat``,
    ``heartbeat_monitoring``) and the client connection ``name``.
    """

    FRAME_BUFFER = 10
    # Interval between sending heartbeats based on the heartbeat(timeout)
    HEARTBEAT_INTERVAL_MULTIPLIER = 0.5
    # Grace period, as a multiple of heartbeat(timeout), within which the
    # server must send something before the connection is treated as hung
    HEARTBEAT_GRACE_MULTIPLIER = 3
    # Heartbeat frame for channel 0, marshalled once and reused on every send
    _HEARTBEAT = pamqp.frame.marshal(Heartbeat(), 0)

    @staticmethod
    def _parse_ca_data(data) -> typing.Optional[bytes]:
        """Base64-decode ``cadata`` from the URL; falsy values pass through."""
        return b64decode(data) if data else data

    def __init__(self, url: URLorStr, *, parent=None,
                 loop: asyncio.AbstractEventLoop = None):
        super().__init__(loop=loop or asyncio.get_event_loop(), parent=parent)

        self.url = URL(url)

        if self.url.path == "/" or not self.url.path:
            self.vhost = "/"
        else:
            # Drop the leading "/" of the URL path
            self.vhost = self.url.path[1:]

        self._reader_task = None  # type: asyncio.Task
        self.reader = None  # type: asyncio.StreamReader
        self.writer = None  # type: asyncio.StreamWriter
        self.ssl_certs = SSLCerts(
            cafile=self.url.query.get("cafile"),
            capath=self.url.query.get("capath"),
            cadata=self._parse_ca_data(self.url.query.get("cadata")),
            key=self.url.query.get("keyfile"),
            cert=self.url.query.get("certfile"),
            verify=self.url.query.get("no_verify_ssl", "0") == "0",
        )

        self.started = False
        self.__lock = asyncio.Lock()
        self.__drain_lock = asyncio.Lock()

        self.channels = {}  # type: typing.Dict[int, typing.Optional[Channel]]

        self.server_properties = None  # type: spec.Connection.OpenOk
        self.connection_tune = None  # type: spec.Connection.TuneOk

        self.last_channel = 1

        self.heartbeat_monitoring = parse_bool(
            self.url.query.get("heartbeat_monitoring", "1"))
        self.heartbeat_timeout = parse_int(self.url.query.get(
            "heartbeat", "0"))
        self.heartbeat_last_received = 0
        self.last_channel_lock = asyncio.Lock()
        self.connected = asyncio.Event()
        self.connection_name = self.url.query.get("name")

    @property
    def lock(self):
        """Lock serialising frame reads; raises ``RuntimeError`` once closed."""
        if self.is_closed:
            raise RuntimeError("%r closed" % self)

        return self.__lock

    async def drain(self):
        """Flush the writer buffer, letting one caller drain at a time."""
        async with self.__drain_lock:
            if not self.writer:
                raise RuntimeError("Writer is %r" % self.writer)
            return await self.writer.drain()

    @property
    def is_opened(self):
        return self.writer is not None and not self.is_closed

    def __str__(self):
        return str(censor_url(self.url))

    def _get_ssl_context(self):
        """Build a client-side TLS context from the URL's TLS parameters.

        ``Purpose.SERVER_AUTH`` is the purpose for a *client* that verifies
        the server it connects to; ``Purpose.CLIENT_AUTH`` would build a
        server-side context, which skips server verification on older Pythons
        and cannot be used for client connections on newer ones.
        """
        context = ssl.create_default_context(
            ssl.Purpose.SERVER_AUTH,
            capath=self.ssl_certs.capath,
            cafile=self.ssl_certs.cafile,
            cadata=self.ssl_certs.cadata,
        )

        # Client certificate; the key may be bundled inside the cert file
        if self.ssl_certs.cert:
            context.load_cert_chain(self.ssl_certs.cert, self.ssl_certs.key)

        if not self.ssl_certs.verify:
            context.check_hostname = False
            context.verify_mode = ssl.CERT_NONE

        return context

    def _client_properties(self, **kwargs):
        """Return the client-properties table sent in ``Connection.StartOk``.

        Extra properties may be given directly as keyword arguments (this is
        how ``connect()`` passes the user's ``client_properties`` dict) or,
        for backward compatibility, nested under a ``client_properties``
        keyword. Both forms override the defaults.
        """
        properties = {
            "platform": PLATFORM,
            "version": __version__,
            "product": PRODUCT,
            "capabilities": {
                "authentication_failure_close": True,
                "basic.nack": True,
                "connection.blocked": False,
                "consumer_cancel_notify": True,
                "publisher_confirms": True,
            },
            "information": "See https://github.com/mosquito/aiormq/",
        }

        properties.update(parse_connection_name(self.connection_name))

        extra = dict(kwargs)
        nested = extra.pop("client_properties", None) or {}
        properties.update(extra)
        properties.update(nested)
        return properties

    @staticmethod
    def _credentials_class(start_frame: spec.Connection.Start):
        """Pick the first server-offered mechanism known to ``AuthMechanism``.

        :raises exc.AuthenticationError: when no mechanism is supported.
        """
        for mechanism in start_frame.mechanisms.decode().split():
            with suppress(KeyError):
                return AuthMechanism[mechanism]

        raise exc.AuthenticationError(start_frame.mechanisms,
                                      [m.name for m in AuthMechanism])

    async def __rpc(self, request: spec.Frame, wait_response=True):
        """Send a channel-0 method frame and optionally read the reply.

        A ``Connection.Close`` reply closes this connection and is re-raised
        as ``ProbableAuthenticationError`` (code 403) or ``ConnectionClosed``.
        """
        self.writer.write(pamqp.frame.marshal(request, 0))

        if not wait_response:
            return

        _, _, frame = await self.__receive_frame()

        if request.synchronous and frame.name not in request.valid_responses:
            raise spec.AMQPInternalError(frame, dict(frame))
        elif isinstance(frame, spec.Connection.Close):
            if frame.reply_code == 403:
                err = exc.ProbableAuthenticationError(frame.reply_text)
            else:
                err = exc.ConnectionClosed(frame.reply_code, frame.reply_text)

            await self.close(err)

            raise err

        return frame

    @task
    async def connect(self, client_properties: dict = None):
        """Open the socket, run the AMQP handshake and start background tasks.

        :param client_properties: extra client properties for StartOk.
        :raises RuntimeError: if already connected.
        :raises ConnectionError: if the TCP/TLS connection fails.
        """
        if self.writer is not None:
            raise RuntimeError("Already connected")

        ssl_context = None

        if self.url.scheme == "amqps":
            # Loading certificates touches the filesystem; keep it off-loop
            ssl_context = await self.loop.run_in_executor(
                None, self._get_ssl_context)

        try:
            self.reader, self.writer = await asyncio.open_connection(
                self.url.host, self.url.port, ssl=ssl_context)
        except OSError as e:
            raise ConnectionError(*e.args) from e

        try:
            protocol_header = ProtocolHeader()
            self.writer.write(protocol_header.marshal())

            res = await self.__receive_frame()
            _, _, frame = res  # type: spec.Connection.Start
            self.heartbeat_last_received = self.loop.time()
        except EOFError as e:
            raise exc.IncompatibleProtocolError(*e.args) from e

        credentials = self._credentials_class(frame)

        self.server_properties = frame.server_properties

        # noinspection PyTypeChecker
        self.connection_tune = await self.__rpc(
            spec.Connection.StartOk(
                client_properties=self._client_properties(
                    **(client_properties or {})),
                mechanism=credentials.name,
                response=credentials.value(self).marshal(),
            ))  # type: spec.Connection.Tune

        # A positive heartbeat from the URL overrides the server's proposal
        if self.heartbeat_timeout > 0:
            self.connection_tune.heartbeat = self.heartbeat_timeout

        await self.__rpc(
            spec.Connection.TuneOk(
                channel_max=self.connection_tune.channel_max,
                frame_max=self.connection_tune.frame_max,
                heartbeat=self.connection_tune.heartbeat,
            ),
            wait_response=False,
        )

        await self.__rpc(spec.Connection.Open(virtual_host=self.vhost))

        # noinspection PyAsyncCall
        self._reader_task = self.create_task(self.__reader())

        # noinspection PyAsyncCall
        heartbeat_task = self.create_task(self.__heartbeat_task())
        heartbeat_task.add_done_callback(self._on_heartbeat_done)
        self.loop.call_soon(self.connected.set)

        return True

    def _on_heartbeat_done(self, future):
        """Close the connection if the heartbeat task failed."""
        if not future.cancelled() and future.exception():
            self.create_task(
                self.close(ConnectionError("heartbeat task was failed.")))

    async def __heartbeat_task(self):
        """Periodically send heartbeats and close if the server goes silent."""
        if not self.connection_tune.heartbeat:
            return

        heartbeat_interval = (self.connection_tune.heartbeat *
                              self.HEARTBEAT_INTERVAL_MULTIPLIER)
        heartbeat_grace_timeout = (self.connection_tune.heartbeat *
                                   self.HEARTBEAT_GRACE_MULTIPLIER)

        while self.writer:
            # Send heartbeat to server unconditionally
            self.writer.write(self._HEARTBEAT)

            await asyncio.sleep(heartbeat_interval)

            if not self.heartbeat_monitoring:
                continue

            # Check if the server sent us something
            # within the heartbeat grace period
            last_heartbeat = self.loop.time() - self.heartbeat_last_received

            if last_heartbeat <= heartbeat_grace_timeout:
                continue

            await self.close(
                ConnectionError(
                    "Server connection probably hang, last heartbeat "
                    "received %.3f seconds ago" % last_heartbeat))

            return

    async def __receive_frame(self) -> typing.Tuple[int, int, spec.Frame]:
        """Read and unmarshal one complete frame under the read lock."""
        async with self.lock:
            # ``reader`` is reset to None on close; check before using it
            if self.reader is None:
                raise ConnectionError

            frame_header = await self.reader.readexactly(1)

            # A single zero byte is treated as a frame error
            if frame_header == b"\x00":
                raise spec.AMQPFrameError(await self.reader.read())

            frame_header += await self.reader.readexactly(6)

            if not self.started and frame_header.startswith(b"AMQP"):
                raise spec.AMQPSyntaxError
            else:
                self.started = True

            frame_type, _, frame_length = pamqp.frame.frame_parts(frame_header)
            # +1 for the frame-end octet that follows the payload
            frame_payload = await self.reader.readexactly(frame_length + 1)

            return pamqp.frame.unmarshal(frame_header + frame_payload)

    @staticmethod
    def __exception_by_code(frame: spec.Connection.Close):
        """Map a ``Connection.Close`` reply code to an exception instance."""
        if frame.reply_code == 501:
            return exc.ConnectionFrameError(frame.reply_text)
        elif frame.reply_code == 502:
            return exc.ConnectionSyntaxError(frame.reply_text)
        elif frame.reply_code == 503:
            return exc.ConnectionCommandInvalid(frame.reply_text)
        elif frame.reply_code == 504:
            return exc.ConnectionChannelError(frame.reply_text)
        elif frame.reply_code == 505:
            return exc.ConnectionUnexpectedFrame(frame.reply_text)
        elif frame.reply_code == 506:
            return exc.ConnectionResourceError(frame.reply_text)
        elif frame.reply_code == 530:
            return exc.ConnectionNotAllowed(frame.reply_text)
        elif frame.reply_code == 540:
            return exc.ConnectionNotImplemented(frame.reply_text)
        elif frame.reply_code == 541:
            return exc.ConnectionInternalError(frame.reply_text)
        else:
            return exc.ConnectionClosed(frame.reply_code, frame.reply_text)

    @task
    async def __reader(self):
        """Dispatch incoming frames to their channels until EOF or close."""
        try:
            while not self.reader.at_eof():
                weight, channel, frame = await self.__receive_frame()

                # Any frame counts as a sign of life for heartbeat monitoring
                self.heartbeat_last_received = self.loop.time()

                if channel == 0:
                    if isinstance(frame, spec.Connection.CloseOk):
                        return
                    if isinstance(frame, spec.Connection.Close):
                        return await self.close(
                            self.__exception_by_code(frame))
                    elif isinstance(frame, Heartbeat):
                        continue

                    log.error("Unexpected frame %r", frame)
                    continue

                if self.channels.get(channel) is None:
                    # Not inside an except block, so log.exception would
                    # attach a meaningless "NoneType: None" traceback
                    log.error("Got frame for closed channel %d: %r",
                              channel, frame)
                    continue

                ch = self.channels[channel]

                if isinstance(frame, CHANNEL_CLOSE_RESPONSES):
                    self.channels[channel] = None

                await ch.frames.put((weight, frame))
        except asyncio.CancelledError as e:
            log.debug("Reader task cancelled:", exc_info=e)
        except asyncio.IncompleteReadError as e:
            log.debug("Can not read bytes from server:", exc_info=e)
            await self.close(ConnectionError(*e.args))
        except Exception as e:
            log.debug("Reader task exited because:", exc_info=e)
            await self.close(e)

    @staticmethod
    async def __close_writer(writer: asyncio.StreamWriter):
        """Close ``writer`` (if any), awaiting ``wait_closed`` when present."""
        if writer is None:
            return

        writer.close()

        if hasattr(writer, "wait_closed"):
            await writer.wait_closed()

    async def _on_close(self, ex=exc.ConnectionClosed(0, "normal closed")):
        """Send a close frame (best effort) and tear down the streams."""
        frame = (spec.Connection.CloseOk() if isinstance(
            ex, exc.ConnectionClosed) else spec.Connection.Close())

        await asyncio.gather(self.__rpc(frame, wait_response=False),
                             return_exceptions=True)

        writer = self.writer
        self.reader = None
        self.writer = None
        self._reader_task = None

        await asyncio.gather(self.__close_writer(writer),
                             return_exceptions=True)

        # NOTE(review): ``self._reader_task`` was reset to None just above, so
        # this gather receives None and raises TypeError. Awaiting the real
        # reader task here would be a circular wait when close() is awaited
        # from ``__reader`` itself, so it is left as-is pending a decision.
        await asyncio.gather(self._reader_task, return_exceptions=True)

    @property
    def server_capabilities(self) -> ArgumentsType:
        return self.server_properties["capabilities"]

    @property
    def basic_nack(self) -> bool:
        return self.server_capabilities.get("basic.nack")

    @property
    def consumer_cancel_notify(self) -> bool:
        return self.server_capabilities.get("consumer_cancel_notify")

    @property
    def exchange_exchange_bindings(self) -> bool:
        return self.server_capabilities.get("exchange_exchange_bindings")

    @property
    def publisher_confirms(self):
        return self.server_capabilities.get("publisher_confirms")

    async def channel(self, channel_number: int = None,
                      publisher_confirms=True,
                      frame_buffer=FRAME_BUFFER, **kwargs) -> Channel:
        """Open a new channel, allocating a free number when none is given.

        :raises RuntimeError: if the connection is closed.
        :raises ValueError: on an unsupported ``publisher_confirms`` request
            or an invalid / already used channel number.
        """
        await self.connected.wait()

        if self.is_closed:
            raise RuntimeError("%r closed" % self)

        if not self.publisher_confirms and publisher_confirms:
            raise ValueError("Server doesn't support publisher_confirms")

        if channel_number is None:
            async with self.last_channel_lock:
                if self.channels:
                    self.last_channel = max(self.channels.keys())

                while self.last_channel in self.channels.keys():
                    self.last_channel += 1

                    if self.last_channel > 65535:
                        log.warning("Resetting channel number for %r", self)
                        self.last_channel = 1
                        # switching context for prevent blocking event-loop
                        await asyncio.sleep(0)

                channel_number = self.last_channel
        elif channel_number in self.channels:
            raise ValueError("Channel %d already used" % channel_number)

        if channel_number < 0 or channel_number > 65535:
            raise ValueError("Channel number too large")

        channel = Channel(self, channel_number, frame_buffer=frame_buffer,
                          publisher_confirms=publisher_confirms, **kwargs)

        self.channels[channel_number] = channel

        try:
            await channel.open()
        except Exception:
            self.channels[channel_number] = None
            raise

        return channel

    async def __aenter__(self):
        await self.connect()
class Connection(Base):
    """AMQP connection over an asyncio stream pair, configured from a URL.

    The URL path selects the virtual host; query parameters configure TLS
    (``cafile``, ``capath``, ``cadata``, ``keyfile``, ``certfile``,
    ``no_verify_ssl``) and heartbeats (``heartbeat``,
    ``heartbeat_monitoring``).
    """

    FRAME_BUFFER = 10
    # Interval between sending heartbeats based on the heartbeat(timeout)
    HEARTBEAT_INTERVAL_MULTIPLIER = 0.5
    # Grace period, as a multiple of heartbeat(timeout), within which the
    # server must send something before the connection is treated as hung
    HEARTBEAT_GRACE_MULTIPLIER = 3
    # Heartbeat frame for channel 0, marshalled once and reused
    _HEARTBEAT = pamqp.frame.marshal(Heartbeat(), 0)

    @staticmethod
    def _parse_ca_data(data):
        """Base64-decode ``cadata`` from the URL; falsy values pass through."""
        return b64decode(data) if data else data

    def __init__(self, url: URLorStr, *, parent=None,
                 loop: asyncio.AbstractEventLoop = None):
        super().__init__(loop=loop or asyncio.get_event_loop(), parent=parent)

        self.url = URL(url)

        if self.url.path == '/' or not self.url.path:
            self.vhost = '/'
        else:
            # Drop the leading "/" of the URL path
            self.vhost = self.url.path[1:]

        self.reader = None  # type: asyncio.StreamReader
        self.writer = None  # type: asyncio.StreamWriter
        self.ssl_certs = SSLCerts(
            cafile=self.url.query.get('cafile'),
            capath=self.url.query.get('capath'),
            cadata=self._parse_ca_data(self.url.query.get('cadata')),
            key=self.url.query.get('keyfile'),
            cert=self.url.query.get('certfile'),
            verify=self.url.query.get('no_verify_ssl', '0') == '0')

        self.started = False
        self.__lock = asyncio.Lock(loop=self.loop)
        self.__drain_lock = asyncio.Lock(loop=self.loop)

        self.channels = {}  # type: typing.Dict[int, typing.Optional[Channel]]

        self.server_properties = None  # type: spec.Connection.OpenOk
        self.connection_tune = None  # type: spec.Connection.TuneOk

        self.last_channel = 0

        self.heartbeat_monitoring = parse_bool(
            self.url.query.get('heartbeat_monitoring', '1'))
        self.heartbeat_timeout = parse_int(self.url.query.get(
            'heartbeat', '0'))
        self.heartbeat_last_received = 0
        self.last_channel_lock = asyncio.Lock(loop=self.loop)
        self.connected = asyncio.Event(loop=self.loop)

    @property
    def lock(self):
        """Lock serialising frame reads; raises ``RuntimeError`` once closed."""
        if self.is_closed:
            raise RuntimeError('%r closed' % self)

        return self.__lock

    async def drain(self):
        """Flush the writer buffer, letting one caller drain at a time."""
        async with self.__drain_lock:
            return await self.writer.drain()

    @property
    def is_opened(self):
        return self.writer is not None and not self.is_closed

    def __str__(self):
        return str(censor_url(self.url))

    def _get_ssl_context(self):
        """Build a client-side TLS context from the URL's TLS parameters.

        ``Purpose.SERVER_AUTH`` is the purpose for a *client* verifying the
        server; ``Purpose.CLIENT_AUTH`` would build a server-side context
        that does not verify the server certificate.
        """
        context = ssl.create_default_context(
            ssl.Purpose.SERVER_AUTH,
            capath=self.ssl_certs.capath,
            cafile=self.ssl_certs.cafile,
            cadata=self.ssl_certs.cadata,
        )

        # Client certificate; the key may be bundled inside the cert file
        if self.ssl_certs.cert:
            context.load_cert_chain(
                self.ssl_certs.cert,
                self.ssl_certs.key,
            )

        if not self.ssl_certs.verify:
            context.check_hostname = False
            context.verify_mode = ssl.CERT_NONE

        return context

    @staticmethod
    def _client_capabilities():
        """Return the fixed client-properties table sent in StartOk."""
        return {
            'platform': PLATFORM,
            'version': __version__,
            'product': PRODUCT,
            'capabilities': {
                'authentication_failure_close': True,
                'basic.nack': True,
                'connection.blocked': False,
                'consumer_cancel_notify': True,
                'publisher_confirms': True
            },
            'information': 'See https://github.com/mosquito/aiormq/',
        }

    @staticmethod
    def _credentials_class(start_frame: spec.Connection.Start):
        """Pick the first server-offered mechanism known to ``AuthMechanism``.

        :raises exc.AuthenticationError: when no mechanism is supported.
        """
        for mechanism in start_frame.mechanisms.decode().split():
            with suppress(KeyError):
                return AuthMechanism[mechanism]

        raise exc.AuthenticationError(start_frame.mechanisms,
                                      [m.name for m in AuthMechanism])

    async def __rpc(self, request: spec.Frame, wait_response=True):
        """Send a channel-0 method frame and optionally read the reply."""
        self.writer.write(pamqp.frame.marshal(request, 0))

        if not wait_response:
            return

        _, _, frame = await self.__receive_frame()

        if request.synchronous and frame.name not in request.valid_responses:
            raise spec.AMQPInternalError(frame, frame)
        elif isinstance(frame, spec.Connection.Close):
            if frame.reply_code == 403:
                raise exc.ProbableAuthenticationError(frame.reply_text)

            raise exc.ConnectionClosed(frame.reply_code, frame.reply_text)

        return frame

    @task
    async def connect(self):
        """Open the socket, run the AMQP handshake and start background tasks.

        :raises RuntimeError: if already connected.
        :raises ConnectionError: if the TCP/TLS connection fails.
        """
        if self.writer is not None:
            raise RuntimeError("Already connected")

        ssl_context = None

        if self.url.scheme == 'amqps':
            # Loading certificates touches the filesystem; keep it off-loop
            ssl_context = await self.loop.run_in_executor(
                None, self._get_ssl_context)

        try:
            self.reader, self.writer = await asyncio.open_connection(
                self.url.host, self.url.port, ssl=ssl_context,
                loop=self.loop)
        except OSError as e:
            raise ConnectionError(*e.args) from e

        try:
            protocol_header = ProtocolHeader()
            self.writer.write(protocol_header.marshal())

            res = await self.__receive_frame()
            _, _, frame = res  # type: spec.Connection.Start
            self.heartbeat_last_received = self.loop.time()
        except EOFError as e:
            raise exc.IncompatibleProtocolError(*e.args) from e

        credentials = self._credentials_class(frame)

        self.server_properties = frame.server_properties

        # noinspection PyTypeChecker
        self.connection_tune = await self.__rpc(
            spec.Connection.StartOk(
                client_properties=self._client_capabilities(),
                mechanism=credentials.name,
                response=credentials.value(self).marshal())
        )  # type: spec.Connection.Tune

        # A positive heartbeat from the URL overrides the server's proposal
        if self.heartbeat_timeout > 0:
            self.connection_tune.heartbeat = self.heartbeat_timeout

        await self.__rpc(spec.Connection.TuneOk(
            channel_max=self.connection_tune.channel_max,
            frame_max=self.connection_tune.frame_max,
            heartbeat=self.connection_tune.heartbeat,
        ), wait_response=False)

        await self.__rpc(spec.Connection.Open(virtual_host=self.vhost))

        # noinspection PyAsyncCall
        self.create_task(self.__reader())

        # noinspection PyAsyncCall
        self.create_task(self.__heartbeat_task())

        self.loop.call_soon(self.connected.set)

        return True

    async def __heartbeat_task(self):
        """Periodically send heartbeats and close if the server goes silent."""
        if not self.connection_tune.heartbeat:
            return

        heartbeat_interval = (self.connection_tune.heartbeat *
                              self.HEARTBEAT_INTERVAL_MULTIPLIER)
        heartbeat_grace_timeout = (self.connection_tune.heartbeat *
                                   self.HEARTBEAT_GRACE_MULTIPLIER)

        while True:
            await asyncio.sleep(heartbeat_interval, loop=self.loop)

            # _on_close() resets the writer; stop instead of crashing on None
            if self.writer is None:
                return

            # Send heartbeat to server unconditionally
            self.writer.write(self._HEARTBEAT)

            if not self.heartbeat_monitoring:
                continue

            # Check if the server sent us something
            # within the heartbeat grace period
            last_heartbeat = self.loop.time() - self.heartbeat_last_received

            if last_heartbeat <= heartbeat_grace_timeout:
                continue

            await self.close(
                ConnectionError(
                    'Server connection probably hang, last heartbeat '
                    'received %.3f seconds ago' % last_heartbeat))

            return

    @task
    async def __receive_frame(self) -> typing.Tuple[int, int, spec.Frame]:
        """Read and unmarshal one complete frame under the read lock."""
        async with self.lock:
            frame_header = await self.reader.readexactly(1)

            # A single zero byte is treated as a frame error
            if frame_header == b'\x00':
                raise spec.AMQPFrameError(await self.reader.read())

            frame_header += await self.reader.readexactly(6)

            if not self.started and frame_header.startswith(b'AMQP'):
                raise spec.AMQPSyntaxError
            else:
                self.started = True

            frame_type, _, frame_length = pamqp.frame.frame_parts(frame_header)
            # +1 for the frame-end octet that follows the payload
            frame_payload = await self.reader.readexactly(frame_length + 1)

            return pamqp.frame.unmarshal(frame_header + frame_payload)

    @staticmethod
    def __exception_by_code(frame: spec.Connection.Close):
        """Map a ``Connection.Close`` reply code to an exception instance."""
        if frame.reply_code == 501:
            return exc.ConnectionFrameError(frame.reply_text)
        elif frame.reply_code == 502:
            return exc.ConnectionSyntaxError(frame.reply_text)
        elif frame.reply_code == 503:
            return exc.ConnectionCommandInvalid(frame.reply_text)
        elif frame.reply_code == 504:
            return exc.ConnectionChannelError(frame.reply_text)
        elif frame.reply_code == 505:
            return exc.ConnectionUnexpectedFrame(frame.reply_text)
        elif frame.reply_code == 506:
            return exc.ConnectionResourceError(frame.reply_text)
        elif frame.reply_code == 530:
            return exc.ConnectionNotAllowed(frame.reply_text)
        elif frame.reply_code == 540:
            return exc.ConnectionNotImplemented(frame.reply_text)
        elif frame.reply_code == 541:
            return exc.ConnectionInternalError(frame.reply_text)
        else:
            return exc.ConnectionClosed(frame.reply_code, frame.reply_text)

    async def __reader(self):
        """Dispatch incoming frames to their channels until EOF or close."""
        try:
            while not self.reader.at_eof():
                weight, channel, frame = await self.__receive_frame()

                # Any frame counts as a sign of life for heartbeat monitoring
                self.heartbeat_last_received = self.loop.time()

                if channel == 0:
                    if isinstance(frame, spec.Connection.Close):
                        return await self.close(
                            self.__exception_by_code(frame))
                    elif isinstance(frame, Heartbeat):
                        continue

                    log.error('Unexpected frame %r', frame)
                    continue

                if self.channels.get(channel) is None:
                    # Not inside an except block, so log.exception would
                    # attach a meaningless "NoneType: None" traceback
                    log.error("Got frame for closed channel %d: %r",
                              channel, frame)
                    continue

                ch = self.channels[channel]

                channel_close_responses = (spec.Channel.Close,
                                           spec.Channel.CloseOk)

                if isinstance(frame, channel_close_responses):
                    self.channels[channel] = None

                await ch.frames.put((weight, frame))
        except asyncio.CancelledError as e:
            log.debug("Reader task cancelled:", exc_info=e)
        except asyncio.IncompleteReadError as e:
            log.debug("Can not read bytes from server:", exc_info=e)
            await self.close(ConnectionError(*e.args))
        except Exception as e:
            log.debug("Reader task exited because:", exc_info=e)
            await self.close(e)

    async def _on_close(self, exc=exc.ConnectionClosed(0, 'normal closed')):
        """Drop the stream references and close the writer, if one exists."""
        writer = self.writer
        self.reader = None
        self.writer = None

        # Never connected (or already torn down): nothing to close
        if writer is None:
            return

        # noinspection PyShadowingNames
        writer.close()

        # StreamWriter.wait_closed() is not available on every Python version
        if hasattr(writer, 'wait_closed'):
            return await writer.wait_closed()

    @property
    def server_capabilities(self) -> ArgumentsType:
        return self.server_properties['capabilities']

    @property
    def basic_nack(self) -> bool:
        return self.server_capabilities.get('basic.nack')

    @property
    def consumer_cancel_notify(self) -> bool:
        return self.server_capabilities.get('consumer_cancel_notify')

    @property
    def exchange_exchange_bindings(self) -> bool:
        return self.server_capabilities.get('exchange_exchange_bindings')

    @property
    def publisher_confirms(self):
        return self.server_capabilities.get('publisher_confirms')

    async def channel(self, channel_number: int = None,
                      publisher_confirms=True,
                      frame_buffer=FRAME_BUFFER, **kwargs) -> Channel:
        """Open a new channel, allocating a free number when none is given.

        :raises RuntimeError: if the connection is closed.
        :raises ValueError: on an unsupported ``publisher_confirms`` request
            or an invalid / already used channel number.
        """
        await self.connected.wait()

        if self.is_closed:
            raise RuntimeError('%r closed' % self)

        if not self.publisher_confirms and publisher_confirms:
            raise ValueError("Server doesn't support publisher_confirms")

        if channel_number is None:
            async with self.last_channel_lock:
                self.last_channel += 1

                while self.last_channel in self.channels.keys():
                    self.last_channel += 1

                    if self.last_channel > 65535:
                        log.warning("Resetting channel number for %r", self)
                        self.last_channel = 1
                        # switching context for prevent blocking event-loop
                        await asyncio.sleep(0, loop=self.loop)

                channel_number = self.last_channel
        elif channel_number in self.channels:
            raise ValueError("Channel %d already used" % channel_number)

        if channel_number < 0 or channel_number > 65535:
            raise ValueError('Channel number too large')

        channel = Channel(self, channel_number, frame_buffer=frame_buffer,
                          publisher_confirms=publisher_confirms, **kwargs)

        self.channels[channel_number] = channel

        try:
            await channel.open()
        except Exception:
            self.channels[channel_number] = None
            raise

        return channel

    async def __aenter__(self):
        await self.connect()
def send_heartbeat(self):
    """Hand a freshly built Heartbeat frame to the frame writer.

    :return: None
    """
    heartbeat_frame = Heartbeat()
    self._write_frame(heartbeat_frame)
def test_channel0_heartbeat(self):
    """Channel0.on_frame returns None for a Heartbeat frame."""
    channel0 = Channel0(self.connection)
    result = channel0.on_frame(Heartbeat())
    self.assertIsNone(result)
class Connection(Base, AbstractConnection):
    """AMQP connection with dedicated reader and writer tasks.

    Outgoing frames go through ``write_queue`` and are written by the writer
    task; a heartbeat is sent whenever the queue stays empty for ``timeout``
    seconds. Settings come from the URL (vhost from the path; TLS,
    ``timeout``, ``heartbeat``, ``auth`` and ``name`` from the query).
    """

    FRAME_BUFFER_SIZE = 10
    # Interval between sending heartbeats based on the heartbeat(timeout)
    HEARTBEAT_INTERVAL_MULTIPLIER = 0.5
    # Allow three missed heartbeats (based on heartbeat(timeout)
    HEARTBEAT_GRACE_MULTIPLIER = 3
    _HEARTBEAT = ChannelFrame(
        frames=(Heartbeat(),),
        channel_number=0,
    )

    READER_CLOSE_TIMEOUT = 2

    _reader_task: TaskType
    _writer_task: TaskType
    write_queue: asyncio.Queue
    server_properties: ArgumentsType
    connection_tune: spec.Connection.Tune
    channels: Dict[int, Optional[AbstractChannel]]

    @staticmethod
    def _parse_ca_data(data: Optional[str]) -> Optional[bytes]:
        """Base64-decode ``cadata`` from the URL; falsy values become None."""
        return b64decode(data) if data else None

    def __init__(
        self,
        url: URLorStr,
        *,
        loop: asyncio.AbstractEventLoop = None,
        context: ssl.SSLContext = None
    ):
        super().__init__(loop=loop or asyncio.get_event_loop(), parent=None)

        self.url = URL(url)
        if self.url.is_absolute() and not self.url.port:
            self.url = self.url.with_port(DEFAULT_PORTS[self.url.scheme])

        if self.url.path == "/" or not self.url.path:
            self.vhost = "/"
        else:
            # Drop the leading "/" of the URL path
            self.vhost = self.url.path[1:]

        # An explicit SSLContext takes precedence over the URL TLS parameters
        self.ssl_context = context
        self.ssl_certs = SSLCerts(
            cafile=self.url.query.get("cafile"),
            capath=self.url.query.get("capath"),
            cadata=self._parse_ca_data(self.url.query.get("cadata")),
            key=self.url.query.get("keyfile"),
            cert=self.url.query.get("certfile"),
            verify=self.url.query.get("no_verify_ssl", "0") == "0",
        )

        self.started = False
        self.channels = {}
        self.write_queue = asyncio.Queue(
            maxsize=self.FRAME_BUFFER_SIZE,
        )

        self.last_channel = 1

        self.timeout = parse_int(self.url.query.get("timeout", "60"))
        self.heartbeat_timeout = parse_heartbeat(
            self.url.query.get("heartbeat", "60"),
        )
        self.last_channel_lock = asyncio.Lock()
        self.connected = asyncio.Event()
        self.connection_name = self.url.query.get("name")

        # Sent in the Connection.Close frame when the writer is cancelled
        self.__close_reply_code: int = REPLY_SUCCESS
        self.__close_reply_text: str = "normally closed"
        self.__close_class_id: int = 0
        self.__close_method_id: int = 0

    async def ready(self) -> None:
        """Wait until the reader task has started."""
        await self.connected.wait()

    def set_close_reason(
        self,
        reply_code: int = REPLY_SUCCESS,
        reply_text: str = "normally closed",
        class_id: int = 0,
        method_id: int = 0,
    ) -> None:
        """Set the values sent in the Connection.Close frame on shutdown."""
        self.__close_reply_code = reply_code
        self.__close_reply_text = reply_text
        self.__close_class_id = class_id
        self.__close_method_id = method_id

    @property
    def is_opened(self) -> bool:
        """True while the writer task is running and the connection is open.

        The previous expression ``not task.done() is not None`` parsed as
        ``not (task.done() is not None)`` and was therefore always False.
        Before ``connect()`` there is no writer task, which reads as False.
        """
        writer_task = getattr(self, "_writer_task", None)
        return (
            writer_task is not None
            and not writer_task.done()
            and not self.is_closed
        )

    def __str__(self) -> str:
        return str(censor_url(self.url))

    def _get_ssl_context(self) -> ssl.SSLContext:
        """Build a client-side TLS context from the URL's TLS parameters."""
        context = ssl.create_default_context(
            ssl.Purpose.SERVER_AUTH,
            capath=self.ssl_certs.capath,
            cafile=self.ssl_certs.cafile,
            cadata=self.ssl_certs.cadata,
        )

        if self.ssl_certs.cert:
            context.load_cert_chain(self.ssl_certs.cert, self.ssl_certs.key)

        if not self.ssl_certs.verify:
            context.check_hostname = False
            context.verify_mode = ssl.CERT_NONE

        return context

    def _client_properties(self, **kwargs: Any) -> Dict[str, Any]:
        """Return the StartOk client properties; ``kwargs`` override them."""
        properties = {
            "platform": PLATFORM,
            "version": __version__,
            "product": PRODUCT,
            "capabilities": {
                "authentication_failure_close": True,
                "basic.nack": True,
                "connection.blocked": False,
                "consumer_cancel_notify": True,
                "publisher_confirms": True,
            },
            "information": "See https://github.com/mosquito/aiormq/",
        }

        properties.update(
            parse_connection_name(self.connection_name),
        )
        properties.update(kwargs)
        return properties

    def _credentials_class(
        self,
        start_frame: spec.Connection.Start,
    ) -> AuthMechanism:
        """Return the URL-requested ``auth`` mechanism if the server offers it.

        :raises AuthenticationError: when the mechanism is unavailable.
        """
        auth_requested = self.url.query.get("auth", "plain").upper()
        auth_available = start_frame.mechanisms.split()
        if auth_requested in auth_available:
            with suppress(KeyError):
                return AuthMechanism[auth_requested]
        raise AuthenticationError(
            start_frame.mechanisms, [m.name for m in AuthMechanism],
        )

    @staticmethod
    async def _rpc(
        request: Frame, writer: asyncio.StreamWriter,
        frame_receiver: FrameReceiver, wait_response: bool = True
    ) -> Optional[FrameTypes]:
        """Send a channel-0 method frame during the handshake.

        Returns the reply (or None when ``wait_response`` is False).
        """
        writer.write(pamqp.frame.marshal(request, 0))

        if not wait_response:
            return None

        _, _, frame = await frame_receiver.get_frame()

        if request.synchronous and frame.name not in request.valid_responses:
            raise AMQPInternalError(
                "one of {!r}".format(request.valid_responses), frame,
            )
        elif isinstance(frame, spec.Connection.Close):
            if frame.reply_code == 403:
                raise ProbableAuthenticationError(frame.reply_text)

            raise ConnectionClosed(frame.reply_code, frame.reply_text)
        return frame

    @task
    async def connect(self, client_properties: dict = None) -> bool:
        """Open the socket, run the AMQP handshake, start reader/writer tasks.

        :param client_properties: extra client properties for StartOk.
        :raises RuntimeError: if this connection was already connected.
        :raises ConnectionError: if the TCP/TLS connection fails.
        """
        if hasattr(self, "_writer_task"):
            raise RuntimeError("Connection already connected")

        ssl_context = self.ssl_context

        if ssl_context is None and self.url.scheme == "amqps":
            # Loading certificates touches the filesystem; keep it off-loop
            ssl_context = await self.loop.run_in_executor(
                None, self._get_ssl_context,
            )

        log.debug("Connecting to: %s", self)
        try:
            reader, writer = await asyncio.open_connection(
                self.url.host, self.url.port, ssl=ssl_context,
            )

            frame_receiver = FrameReceiver(
                reader,
                (self.timeout + 1) * self.HEARTBEAT_GRACE_MULTIPLIER,
            )
        except OSError as e:
            raise ConnectionError(*e.args) from e

        frame: Optional[FrameTypes]

        try:
            protocol_header = ProtocolHeader()
            writer.write(protocol_header.marshal())

            _, _, frame = await frame_receiver.get_frame()
        except EOFError as e:
            raise IncompatibleProtocolError(*e.args) from e

        if not isinstance(frame, spec.Connection.Start):
            raise AMQPInternalError("Connection.StartOk", frame)

        credentials = self._credentials_class(frame)

        server_properties: ArgumentsType = frame.server_properties

        try:
            frame = await self._rpc(
                spec.Connection.StartOk(
                    client_properties=self._client_properties(
                        **(client_properties or {}),
                    ),
                    mechanism=credentials.name,
                    response=credentials.value(self).marshal(),
                ),
                writer=writer,
                frame_receiver=frame_receiver,
            )

            if not isinstance(frame, spec.Connection.Tune):
                raise AMQPInternalError("Connection.Tune", frame)

            connection_tune: spec.Connection.Tune = frame
            connection_tune.heartbeat = self.heartbeat_timeout

            await self._rpc(
                spec.Connection.TuneOk(
                    channel_max=connection_tune.channel_max,
                    frame_max=connection_tune.frame_max,
                    heartbeat=connection_tune.heartbeat,
                ),
                writer=writer,
                frame_receiver=frame_receiver,
                wait_response=False,
            )

            frame = await self._rpc(
                spec.Connection.Open(virtual_host=self.vhost),
                writer=writer,
                frame_receiver=frame_receiver,
            )

            if not isinstance(frame, spec.Connection.OpenOk):
                raise AMQPInternalError("Connection.OpenOk", frame)

            # noinspection PyAsyncCall
            self._reader_task = self.create_task(self.__reader(frame_receiver))
            self._reader_task.add_done_callback(self._on_reader_done)

            # noinspection PyAsyncCall
            self._writer_task = self.create_task(self.__writer(writer))
        except Exception as e:
            # Once started, the writer task owns and closes the stream;
            # before that nothing else references it, so close it here
            # instead of leaking the socket
            if not hasattr(self, "_writer_task"):
                writer.close()
            await self.close(e)
            raise

        self.connection_tune = connection_tune
        self.server_properties = server_properties
        return True

    def _on_reader_done(self, task: asyncio.Task) -> None:
        """Stop the writer when the reader exits; close on reader failure."""
        log.debug("Reader exited for %r", self)

        if not self._writer_task.done():
            self._writer_task.cancel()

        if not task.cancelled() and task.exception() is not None:
            log.debug("Cancelling cause reader exited abnormally")
            self.set_close_reason(
                reply_code=500, reply_text="reader unexpected closed",
            )
            self.create_task(self.close(task.exception()))

    async def __reader(self, frame_receiver: FrameReceiver) -> None:
        """Dispatch incoming frames to their channels.

        A remote ``Connection.Close`` is answered with ``CloseOk`` and then
        raised as the matching exception.
        """
        self.connected.set()

        async for weight, channel, frame in frame_receiver:
            log.debug(
                "Received frame %r in channel #%d weight=%s on %r",
                frame, channel, weight, self,
            )

            if channel == 0:
                if isinstance(frame, spec.Connection.CloseOk):
                    return
                if isinstance(frame, spec.Connection.Close):
                    # Not inside an except block, so log.exception would
                    # attach a meaningless "NoneType: None" traceback
                    log.error(
                        "Unexpected connection close from remote \"%s\", "
                        "Connection.Close(reply_code=%r, reply_text=%r)",
                        self, frame.reply_code, frame.reply_text,
                    )

                    self.write_queue.put_nowait(
                        ChannelFrame(
                            channel_number=0,
                            frames=[spec.Connection.CloseOk()],
                        ),
                    )

                    raise exception_by_code(frame)
                elif isinstance(frame, Heartbeat):
                    continue
                elif isinstance(frame, spec.Channel.CloseOk):
                    self.channels.pop(channel, None)

                log.error("Unexpected frame %r", frame)
                continue

            ch: Optional[AbstractChannel] = self.channels.get(channel)
            if ch is None:
                log.error(
                    "Got frame for closed channel %d: %r", channel, frame,
                )
                continue

            if isinstance(frame, CHANNEL_CLOSE_RESPONSES):
                self.channels[channel] = None

            await ch.frames.put((weight, frame))

    async def __frame_iterator(self) -> AsyncIterableType[ChannelFrame]:
        """Yield queued frames; yield a heartbeat after ``timeout`` idle seconds."""
        while not self.is_closed:
            try:
                yield await asyncio.wait_for(
                    self.write_queue.get(), timeout=self.timeout,
                )
                self.write_queue.task_done()
            except asyncio.TimeoutError:
                yield self._HEARTBEAT

    async def __writer(self, writer: asyncio.StreamWriter) -> None:
        """Write queued frames to the socket.

        On cancellation, if the writer is still usable, send
        ``Connection.Close`` with the configured close reason and close the
        stream.
        """
        channel_frame: ChannelFrame

        try:
            async for channel_frame in self.__frame_iterator():
                log.debug("Prepare to send %r", channel_frame)

                frame: FrameTypes
                for frame in channel_frame.frames:
                    log.debug(
                        "Sending frame %r in channel #%d on %r",
                        frame,
                        channel_frame.channel_number,
                        self,
                    )

                    try:
                        writer.write(
                            pamqp.frame.marshal(
                                frame, channel_frame.channel_number,
                            ),
                        )
                    except BaseException as e:
                        log.exception(
                            "Failed to write frame to channel %d: %r",
                            channel_frame.channel_number,
                            frame,
                        )
                        raise asyncio.CancelledError from e

                    if isinstance(frame, spec.Connection.CloseOk):
                        return

                if (
                    channel_frame.drain_future is not None and
                    not channel_frame.drain_future.done()
                ):
                    channel_frame.drain_future.set_result(
                        await writer.drain(),
                    )
        except asyncio.CancelledError:
            if not self.__check_writer(writer):
                raise

            frame = spec.Connection.Close(
                reply_code=self.__close_reply_code,
                reply_text=self.__close_reply_text,
                class_id=self.__close_class_id,
                method_id=self.__close_method_id,
            )

            writer.write(pamqp.frame.marshal(frame, 0))

            log.debug("Sending %r to %r", frame, self)

            await writer.drain()
            await self.__close_writer(writer)
            raise
        finally:
            log.debug("Writer exited for %r", self)

    @staticmethod
    async def __close_writer(writer: asyncio.StreamWriter) -> None:
        """Close ``writer`` (if any), awaiting ``wait_closed`` when present."""
        if writer is None:
            return

        writer.close()

        if hasattr(writer, "wait_closed"):
            await writer.wait_closed()

    @staticmethod
    def __check_writer(writer: asyncio.StreamWriter) -> bool:
        """Return True if ``writer`` still looks writable."""
        if writer is None:
            return False

        if hasattr(writer, "is_closing"):
            return not writer.is_closing()

        if writer.transport:
            return not writer.transport.is_closing()

        return writer.can_write_eof()

    async def _on_close(
        self, ex: Optional[ExceptionType] = ConnectionClosed(0, "normal closed")
    ) -> None:
        """Cancel the reader task; its done-callback then stops the writer.

        ``connect()`` calls ``close()`` on handshake failure, possibly before
        ``_reader_task`` exists, so a missing task is tolerated.
        """
        log.debug("Closing connection %r cause: %r", self, ex)

        reader_task = getattr(self, "_reader_task", None)
        if reader_task is None:
            return

        del self._reader_task

        if not reader_task.done():
            reader_task.cancel()

    @property
    def server_capabilities(self) -> ArgumentsType:
        return self.server_properties["capabilities"]  # type: ignore

    @property
    def basic_nack(self) -> bool:
        return bool(self.server_capabilities.get("basic.nack"))

    @property
    def consumer_cancel_notify(self) -> bool:
        return bool(self.server_capabilities.get("consumer_cancel_notify"))

    @property
    def exchange_exchange_bindings(self) -> bool:
        return bool(self.server_capabilities.get("exchange_exchange_bindings"))

    @property
    def publisher_confirms(self) -> Optional[bool]:
        publisher_confirms = self.server_capabilities.get("publisher_confirms")
        if publisher_confirms is None:
            return None
        return bool(publisher_confirms)

    async def channel(
        self,
        channel_number: int = None,
        publisher_confirms: bool = True,
        frame_buffer_size: int = FRAME_BUFFER_SIZE,
        timeout: TimeoutType = None,
        **kwargs: Any
    ) -> AbstractChannel:
        """Open a new channel, allocating a free number when none is given.

        :raises RuntimeError: if the connection is closed.
        :raises ValueError: on an unsupported ``publisher_confirms`` request
            or an invalid / already used channel number.
        """
        await self.connected.wait()

        if self.is_closed:
            raise RuntimeError("%r closed" % self)

        if not self.publisher_confirms and publisher_confirms:
            raise ValueError("Server doesn't support publisher_confirms")

        if channel_number is None:
            async with self.last_channel_lock:
                if self.channels:
                    self.last_channel = max(self.channels.keys())

                while self.last_channel in self.channels.keys():
                    self.last_channel += 1

                    if self.last_channel > 65535:
                        log.warning("Resetting channel number for %r", self)
                        self.last_channel = 1
                        # switching context for prevent blocking event-loop
                        await asyncio.sleep(0)

                channel_number = self.last_channel
        elif channel_number in self.channels:
            raise ValueError("Channel %d already used" % channel_number)

        if channel_number < 0 or channel_number > 65535:
            raise ValueError("Channel number too large")

        channel = Channel(
            self,
            channel_number,
            frame_buffer=frame_buffer_size,
            publisher_confirms=publisher_confirms,
            **kwargs,
        )

        self.channels[channel_number] = channel

        try:
            await channel.open(timeout=timeout)
        except Exception:
            self.channels[channel_number] = None
            raise

        return channel

    async def __aenter__(self) -> AbstractConnection:
        await self.connect()
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        await self.close(exc_val)