async def _discover(cls, domain: str, session: ClientSession) -> Optional[URL]:
    """Discover a Matrix homeserver base URL via ``.well-known`` lookup.

    Fetches ``https://<domain>/.well-known/matrix/client``, reads
    ``m.homeserver.base_url`` and checks that the advertised server answers
    ``/_matrix/client/versions`` with at least one version.

    :param domain: Host to query for the well-known document.
    :param session: aiohttp session used for both requests.
    :return: The validated homeserver URL, or ``None`` if the well-known
        document does not exist (HTTP 404).
    :raises WellKnownUnexpectedStatus: well-known returned neither 200 nor 404.
    :raises WellKnownNotJSON: the well-known body is not JSON.
    :raises WellKnownMissingHomeserver: ``m.homeserver.base_url`` is absent
        (or the document is not a JSON object).
    :raises WellKnownNotURL: ``base_url`` is not an absolute URL string.
    :raises WellKnownUnsupportedScheme: ``base_url`` is not http/https.
    :raises WellKnownInvalidVersionsResponse: the versions endpoint failed,
        was malformed, or listed no versions.
    """
    well_known = URL.build(scheme="https", host=domain, path="/.well-known/matrix/client")
    async with session.get(well_known) as resp:
        if resp.status == 404:
            return None
        elif resp.status != 200:
            raise WellKnownUnexpectedStatus(resp.status)
        try:
            data = await resp.json(content_type=None)
        except (json.JSONDecodeError, ContentTypeError) as e:
            raise WellKnownNotJSON() from e

    try:
        homeserver_url = data["m.homeserver"]["base_url"]
    except (KeyError, TypeError) as e:
        # TypeError: valid JSON that is not an object (e.g. a list or string),
        # or an "m.homeserver" member that is not an object.
        raise WellKnownMissingHomeserver() from e
    if not isinstance(homeserver_url, str):
        raise WellKnownNotURL()
    parsed_url = URL(homeserver_url)
    if not parsed_url.is_absolute():
        raise WellKnownNotURL()
    elif parsed_url.scheme not in ("http", "https"):
        raise WellKnownUnsupportedScheme(parsed_url.scheme)

    try:
        async with session.get(parsed_url / "_matrix/client/versions") as resp:
            versions = VersionsResponse.deserialize(await resp.json())
        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O`` and an empty versions list would then be accepted.
        if len(versions.versions) == 0:
            raise ValueError("no versions defined in /_matrix/client/versions response")
    except (ClientError, json.JSONDecodeError, SerializerError, ValueError, AssertionError) as e:
        raise WellKnownInvalidVersionsResponse() from e
    return parsed_url
def make_url(self, path):
    """Build a URL for *path* on this server.

    When URL assertions are enabled, *path* must be relative and is joined
    onto ``self._root``; when they are skipped, *path* is appended to the
    textual root verbatim.
    """
    parsed = URL(path)
    if self.skip_url_asserts:
        return URL(str(self._root) + path)
    assert not parsed.is_absolute()
    return self._root.join(parsed)
def play_responses(cassette, vcr_request):
    """Replay *vcr_request* from *cassette*, following recorded redirects.

    Every intermediate 3xx response is collected into ``history``, which is
    attached to each built response, and the final response is returned.
    """
    history = []
    vcr_response = cassette.play_response(vcr_request)
    response = build_response(vcr_request, vcr_response, history)

    # If we're following redirects, continue playing until we reach
    # our final destination.
    while 300 <= response.status <= 399:
        location = response.headers.get("location")
        if not location:
            # A 3xx without a Location (e.g. 304 Not Modified) is not a
            # redirect; previously this raised KeyError.
            break

        # RFC 3986 reference resolution: an absolute Location wins, while a
        # relative one (absolute path, relative path, or carrying a query
        # string) is resolved against the URL that produced the redirect.
        # ``with_path`` would have %-encoded any "?" and ignored the
        # current directory for relative paths.
        next_url = URL(response.url).join(URL(location))

        # Make a stub VCR request that we can then use to look up the recorded
        # VCR request saved to the cassette. This feels a little hacky and
        # may have edge cases based on the headers we're providing (e.g. if
        # there's a matcher that is used to filter by headers).
        vcr_request = Request(
            "GET", str(next_url), None,
            _serialize_headers(response.request_info.headers))
        vcr_request = cassette.find_requests_with_most_matches(
            vcr_request)[0][0]

        # Tack on the response we saw from the redirect into the history
        # list that is added on to the final response.
        history.append(response)
        vcr_response = cassette.play_response(vcr_request)
        response = build_response(vcr_request, vcr_response, history)

    return response
def extract(html_body: str, origin_url: str) -> Dict[str, List[str]]:
    """
    Given an HTML body and its origin URL, extract URLs and categorize them
    as inbound & outbound.

    Inbound links (same origin as *origin_url*) are normalized to
    ``scheme://host[:port]/path``, dropping query and fragment. Relative
    hrefs are resolved against *origin_url* itself, not just its origin.
    Links to a different origin, and non-HTTP schemes (tel:, mailto:,
    ftp: ...), are outbound and kept verbatim. Anchors with no href, or
    with an href that cannot be parsed, are skipped.

    :param html_body: HTML content of the crawled endpoint
    :type html_body: string
    :param origin_url: The URL (endpoint) from which the above body originates
    :type origin_url: string
    :return: ``{"inbound": [...], "outbound": [...]}`` (lists of strings,
        order not significant)
    """
    inbound = set()
    outbound = set()
    page_url = URL(origin_url)
    origin = page_url.origin()
    soup = BeautifulSoup(html_body, "html.parser")

    for a_element in soup.find_all("a", href=True):
        try:
            url = URL(a_element["href"])
        except ValueError:
            # Malformed href (e.g. a broken IPv6 literal): skip the link
            # instead of aborting the whole extraction.
            continue

        if url.is_absolute() and not url.scheme:
            # Protocol-relative ("//host/path"): inherit the page's scheme.
            # yarl's URL.origin() raises ValueError for scheme-less URLs.
            url = page_url.join(url)

        if url.is_absolute():
            if url.origin() != origin:
                outbound.add(str(url))
                continue
            base = origin
        else:
            # tel, mailto, ftp etc
            if url.scheme not in ("http", "https", ""):
                outbound.add(str(url))
                continue
            # Resolve relative paths ("other.html", "../x") against the page
            # itself; resolving against the bare origin dropped the page's
            # directory.
            base = page_url

        target = base.join(URL(url.path)).with_query(None).with_fragment(None)
        inbound.add(target)

    return {
        "inbound": [str(el) for el in inbound],
        "outbound": [str(el) for el in outbound],
    }
def download_recipe_images(recipe_path):
    """Download remote images referenced by a Markdown recipe and relink them.

    Every image node whose destination is an absolute URL is fetched and
    saved as ``<root_path>/images/<file name>``. The node is then pointed
    at that file using a path relative to the recipe's directory. If at
    least one image was processed, the document is re-rendered, re-parsed
    and written back to *recipe_path*.

    Image URLs without a file name (e.g. ending in ``/``) are left untouched.
    """
    with open(recipe_path, 'r', encoding='UTF-8') as recipe_file:
        ast = commonmark_parser.parse(recipe_file.read())

    recipe_dir = os.path.dirname(recipe_path)
    image_found = False
    for node, entering in NodeWalker(ast):
        if not (entering and node.t == "image"):
            continue
        url = URL(node.destination)
        if not url.is_absolute():
            continue
        if not url.name:
            # os.path.join(..., '') would be the images directory itself and
            # open(..., 'wb') would fail with IsADirectoryError.
            print(f"[Images] Skipping {url} (no file name) in {recipe_path}", file=sys.stderr)
            continue

        image_found = True
        with urlopen_user_agent(str(url)) as resp:
            image = resp.read()
        image_path = os.path.join(root_path, 'images', url.name)
        with open(image_path, 'wb') as image_file:
            image_file.write(image)
        node.destination = os.path.relpath(image_path, start=recipe_dir)

    if image_found:
        print(f"[Images] Processed {recipe_path}", file=sys.stderr)
        transformed_source = commonmark_renderer.render(ast)
        recipe = recipe_parser.parse(transformed_source)
        with open(recipe_path, 'w', encoding='UTF-8') as recipe_file:
            recipe_file.write(recipe_serializer.serialize(recipe))
def _build_url(self, str_or_url: StrOrURL) -> URL:
    """Convert *str_or_url* to a request URL.

    Without a configured base URL the value is returned as a :class:`URL`.
    With one, the value must be a relative URL whose path starts with
    ``/`` (checked by ``assert``) and is joined onto ``self._base_url``.
    """
    url = URL(str_or_url)
    if self._base_url is not None:
        assert not url.is_absolute() and url.path.startswith("/")
        return self._base_url.join(url)
    return url
async def _head(self, sess_id):
    """Send HEAD requests for session *sess_id* until one returns 200.

    Redirect statuses (``REDIRECTS_STATUS``) are followed when ``self._ar``
    is truthy, and the followed URL is stored in ``self._urls[sess_id]``.
    Any other status, or a redirect that names no target, closes and
    forgets the session and raises :class:`HeadStatusError`.

    :return: ``(content_length, (sess_id, delay))``, where *delay* is the
        time in seconds taken by the successful HEAD request.
    """
    sess = self._sessions[sess_id]
    while True:
        url = self._urls[sess_id]
        begin = time.monotonic()
        async with sess.head(url.human_repr()) as resp:
            if resp.status == 200:
                # NOTE(review): int(None) raises TypeError if the server
                # omits Content-Length.
                length = int(resp.headers.get(aiohttp.hdrs.CONTENT_LENGTH))
                delay = time.monotonic() - begin
                break

            target = None
            if resp.status in REDIRECTS_STATUS and self._ar:
                target = (resp.headers.get(aiohttp.hdrs.LOCATION)
                          or resp.headers.get(aiohttp.hdrs.URI))
            if target:
                # join() performs RFC 3986 resolution. Absolute targets replace
                # the URL; absolute paths, relative paths and query strings
                # are resolved against the current URL. with_path() lost the
                # query and ignored the current directory.
                self._urls[sess_id] = URL(url).join(URL(target))
                continue

            # Unexpected status, or a redirect without a usable target.
            await self._close(sess_id)
            del self._sessions[sess_id]
            del self._urls[sess_id]
            raise HeadStatusError(url.human_repr(), resp.status)
    return length, (sess_id, delay)
async def request(
    self,
    method: str,
    url: URL,
    *,
    params: Optional[Mapping[str, str]] = None,
    data: Any = None,
    json: Any = None,
    headers: Optional[Dict[str, str]] = None,
    timeout: Optional[aiohttp.ClientTimeout] = None,
) -> AsyncIterator[aiohttp.ClientResponse]:
    """Perform an HTTP request and yield the response unless it is an error.

    A relative *url* is joined onto ``self._base_url / ""``. Presumably the
    trailing empty segment makes relative paths resolve beneath the base
    path; verify against yarl's ``/`` semantics. When no *timeout* is given,
    ``self._timeout`` is used. If the effective timeout has ``sock_read``
    set, its ``total`` is replaced with 180 seconds. A status >= 400 raises
    ``self._exception_map[status]`` (default ``IllegalArgumentError``) with
    the response body text.
    """
    if not url.is_absolute():
        base = self._base_url / ""
        url = base.join(url)
    log.debug("Fetch [%s] %s", method, url)

    effective_timeout = self._timeout if timeout is None else timeout
    if effective_timeout.sock_read is not None:
        effective_timeout = attr.evolve(effective_timeout, total=3 * 60)

    async with self._session.request(
        method,
        url,
        headers=headers,
        params=params,
        json=json,
        data=data,
        timeout=effective_timeout,
    ) as resp:
        if resp.status >= 400:
            err_text = await resp.text()
            err_cls = self._exception_map.get(resp.status, IllegalArgumentError)
            raise err_cls(err_text)
        yield resp
def get_string(tag):
    """Flatten the children of *tag* into a lightly formatted string.

    - Comments are dropped.
    - Text nodes lose their ``\\n``/``\\r`` characters.
    - ``<a>`` becomes ``<a href='...'>text</a>``, with relative hrefs
      joined onto the module-level ``base_url``.
    - ``<br>`` becomes a newline.
    - ``<b>``/``<u>`` become ``<b>...</b>`` plus a blank line.
    - ``<li>`` becomes a ``" - "`` line and ``<p>`` starts a new line.
    - Any other tag contributes its own flattened text.

    The combined result is stripped.
    """
    templates = {
        "b": "<b>{}</b>\n\n",
        "u": "<b>{}</b>\n\n",
        "li": "\n - {}",
        "p": "\n{}",
    }
    pieces = []
    for child in tag.children:
        # Comment subclasses NavigableString, so it must be tested first.
        if isinstance(child, element.Comment):
            continue
        if isinstance(child, element.NavigableString):
            pieces.append(re.sub(r"[\n\r]", "", child.string))
        elif child.name == "a":
            href = URL(child.get("href", "").strip())
            if not href.is_absolute():
                href = base_url.join(href)
            label = "".join(part.strip() for part in child.stripped_strings)
            pieces.append(f"<a href='{href}'>{label}</a>")
        elif child.name == "br":
            pieces.append("\n")
        else:
            inner = get_string(child)
            pieces.append(templates.get(child.name, "{}").format(inner))
    return "".join(pieces).strip()
def to_backend_service(rel_url: URL, origin: URL, version_prefix: str) -> URL:
    """
    Translates relative url to backend catalog service url
    E.g. https://osparc.io/v0/catalog/dags -> http://catalog:8080/v0/dags

    :param rel_url: request URL without scheme/host (asserted relative)
    :param origin: backend URL whose scheme, host and port are kept
    :param version_prefix: version prefix used by the backend service
    :return: *origin* with the rewritten path and *rel_url*'s query. Every
        occurrence of ``/<api_version_prefix>/catalog`` in the path is
        replaced by ``/<version_prefix>``.
    """
    assert not rel_url.is_absolute()  # nosec
    public_prefix = f"/{api_version_prefix}/catalog"
    backend_path = rel_url.path.replace(public_prefix, f"/{version_prefix}")
    backend_url = origin.with_path(backend_path)
    return backend_url.with_query(rel_url.query)
def _clean_urls(v: str) -> List[URL]:
    """Normalize a URL setting into a list of absolute URLs.

    *v* may be a single :class:`URL`, a comma-separated string, or a
    list/tuple of strings and/or URLs. Whitespace around each string entry
    is ignored, so ``"http://a, http://b"`` works.

    :raises ValueError: if an entry is not an absolute URL.
    :return: the parsed URLs. Any other input type yields ``None``, as
        before. NOTE(review): consider raising TypeError once callers are
        checked.
    """
    if isinstance(v, URL):
        return [v]
    if isinstance(v, str):
        entries = v.split(',')
    elif isinstance(v, (list, tuple)):
        entries = v
    else:
        return None

    urls = []
    for entry in entries:
        url = URL(entry.strip()) if isinstance(entry, str) else URL(entry)
        if not url.is_absolute():
            raise ValueError('URL {} is not absolute.'.format(url))
        urls.append(url)
    return urls
def _download_img(self, url: URL) -> bytes:
    """Download the image at *url* and return its raw bytes.

    A relative *url* is resolved against ``self.url`` first.

    :raises httpx.HTTPStatusError: if the response status is not 200.
    """
    if not url.is_absolute():
        url = self.url.join(url)
    logger.info(f'downloading image from [{url}]')
    remote_img = httpx.get(url.human_repr(), headers=header(url))
    if remote_img.status_code != 200:
        # raise_for_status() does not raise for every non-200 status. It
        # never raises for 2xx codes such as 204, and older httpx versions
        # only raised for 4xx/5xx, so an unfollowed redirect body could be
        # returned as the "image". Fail explicitly instead.
        remote_img.raise_for_status()
        raise httpx.HTTPStatusError(
            f"Unexpected status {remote_img.status_code} while downloading image from {url}",
            request=remote_img.request,
            response=remote_img,
        )
    return remote_img.content
async def ws_connect(self, abs_url: URL, *, headers: Optional[Dict[str, str]] = None
                     ) -> AsyncIterator[WSMessage]:
    """Open a web socket to *abs_url* and yield its text messages.

    *abs_url* must be absolute (asserted). Messages whose type is not
    ``aiohttp.WSMsgType.TEXT`` are skipped.
    """
    # TODO: timeout
    assert abs_url.is_absolute(), abs_url
    log.debug("Fetch web socket: %s", abs_url)
    async with self._session.ws_connect(abs_url, headers=headers) as ws:
        async for message in ws:
            if message.type != aiohttp.WSMsgType.TEXT:
                continue
            yield message
async def _build_request(
        setup: HTTPPollingSetup) -> Optional[HTTPRequest]:
    """Build an :class:`HTTPRequest` from an HTTP polling *setup*.

    The request uses the decoded URL path and the raw query string encoded
    as latin-1. A ``host`` header is added (with port, when the URL has
    one) on top of ``setup.headers``. If ``setup.url`` is not an absolute
    URL with a host, an error is logged and ``None`` is returned.
    """
    url = URL(setup.url)
    if url.is_absolute() and url.host:
        host = f"{url.host}:{url.port}" if url.port else url.host
        request_headers = {**setup.headers, "host": host}
        query = url.raw_query_string.encode("latin-1")
        return HTTPRequest(setup.http_version, url.scheme, url.path, query,
                           request_headers, setup.method, setup.body)

    logger.opt(colors=True).error(
        f"<r><bg #f8bbd0>Error parsing url {escape_tag(str(url))}</bg #f8bbd0></r>"
    )
    return None
def _ref_generator(self, bs_result_set):
    """Yield absolute link targets from the ``href`` of each element.

    Elements with no ``href``, an unparsable one, or a URL that carries a
    query string are skipped. Relative hrefs are resolved against
    ``self.url``, and links equal to ``self.url`` are not yielded.
    """
    for ref in bs_result_set:
        raw_href = ref.attrs.get('href')
        if raw_href is None:
            continue
        try:
            href = URL(raw_href)
        except ValueError:
            # Malformed href: skip it instead of terminating the generator.
            continue
        if href.query_string:  # Without QS
            continue
        if not href.is_absolute():
            href = self.url.join(href)
        if href != self.url:
            yield href
async def ws_connect(
    self, abs_url: URL, auth: str, *, headers: Optional[Dict[str, str]] = None
) -> AsyncIterator[WSMessage]:
    """Open an authenticated web socket to *abs_url* and yield text messages.

    *abs_url* must be absolute (asserted). *auth* is sent as the
    ``Authorization`` header, overriding any value in *headers*. Messages
    whose type is not ``aiohttp.WSMsgType.TEXT`` are skipped.
    """
    # TODO: timeout
    assert abs_url.is_absolute(), abs_url
    log.debug("Fetch web socket: %s", abs_url)
    real_headers: CIMultiDict[str] = (
        CIMultiDict() if headers is None else CIMultiDict(headers)
    )
    real_headers["Authorization"] = auth
    async with self._session.ws_connect(abs_url, headers=real_headers) as ws:
        async for message in ws:
            if message.type != aiohttp.WSMsgType.TEXT:
                continue
            yield message
def save(self, url: str) -> str:
    """Queue an image re-upload and return the CDN URL it will live at.

    A scheme-less *url* is treated as ``http://<url>``. The returned URL is
    ``https://img3.weegotr.com`` or ``https://img4.weegotr.com`` (chosen at
    random) with path ``/agents/<supplier><original path>``. The upload
    itself is scheduled through ``upload_image.apply_async``, with
    ``on_upload_failure`` as the error callback.
    """
    if not isinstance(url, str):
        url = str(url)
    source = URL(url)
    if not source.is_absolute():
        source = URL(f"http://{url}")

    cdn_url = (
        source
        .with_path(f'agents/{self.supplier}{source.path}')
        .with_host('img%s.weegotr.com' % random.randint(3, 4))
        .with_scheme('https')
    )
    new_url = str(cdn_url)
    upload_image.apply_async(
        kwargs={'ori_url': str(source), 'url': new_url},
        link_error=on_upload_failure.s(),
    )
    return new_url
def upload_images(self, images: List[str], ori_images) -> List[str]:
    """Map each image URL to its re-hosted (weegotr.com) counterpart.

    For each entry of *images*, the first entry of *ori_images* that
    contains the image's URL path and mentions ``weegotr.com`` is reused.
    Otherwise the image is queued for upload with :meth:`save`. Scheme-less
    image URLs are treated as ``http://``.

    :return: the new URLs in the same order as *images* (``[]`` when
        *images* is empty or ``None``).
    """
    if not images:
        return []

    result = []
    for image_url in images:
        # Parse once per image; this was previously redone for every
        # candidate in ori_images.
        parsed = URL(image_url)
        if not parsed.is_absolute():
            parsed = URL(''.join(['http://', image_url]))
        image_path = parsed.path

        new_url = next(
            (ori for ori in ori_images
             if image_path in ori and 'weegotr.com' in ori),
            None,
        )
        if new_url is None:
            new_url = self.save(image_url)
        # Building the list directly replaces the O(n^2) remove()/append()
        # shuffle, which could also reorder results when a new URL equaled
        # a not-yet-processed entry.
        result.append(new_url)
    return result
def prepend_url(base_url: URL, url: URL, encoded: bool = False) -> URL:
    """Prepend the url.

    Args:
        base_url (URL): Base URL to prepend (a ``str`` is also accepted)
        url (URL): url to prepend (a ``str`` is also accepted)
        encoded (bool): Whether to treat the url as already encoded. This may be needed if the url is JavaScript.

    Returns:
        URL: *url* unchanged when it is absolute. Otherwise *base_url* with
        path ``base_url.path + url.path`` (with ``//`` replaced by ``/`` in
        a single pass) and *url*'s query. *url*'s fragment is not kept.
    """
    base = URL(base_url) if isinstance(base_url, str) else base_url
    target = URL(url) if isinstance(url, str) else url
    if target.is_absolute():
        return target
    combined_path = f"{base.path}{target.path}".replace("//", "/")
    return base.with_path(combined_path, encoded=encoded).with_query(target.query)
def is_one_jump_from_original_domain(url: URL, response: Response) -> bool:
    """
    Check that the current URL is only one response away from the originally queried domain.

    We want to be able to follow potential feed links that point to a different domain than
    the originally queried domain, but not to follow any deeper than that.

    Sub-domains of the original domain are ok. Hosts that merely contain the
    original host as a substring (e.g. "nottest.com" or "test.com.evil.org"
    for "test.com") are not sub-domains.

    i.e: the following are ok
        "test.com" -> "feedhost.com"
        "test.com/feeds" -> "example.com/feed.xml"
        "test.com" -> "feeds.test.com"
    not ok:
        "test.com" -> "feedhost.com" (we stop here) -> "feedhost.com/feeds"

    :param url: URL object (must provide ``is_absolute()`` and ``host``)
    :param response: Response object whose ``history`` items provide ``host``
    :return: boolean
    """
    # This is the first Response in the chain
    if len(response.history) < 2:
        return True

    # The URL is relative, so on the same domain
    if not url.is_absolute():
        return True

    original_host = response.history[0].host

    # URL is same domain
    if url.host == original_host:
        return True

    # URL is sub-domain: require a label boundary, not a substring match
    if original_host and url.host and url.host.endswith("." + original_host):
        return True

    # URL domain and current Response domain are different from original domain
    # (url.host != original_host is already known at this point).
    if response.history[-1].host != original_host:
        return False

    return True
def read_file_lines(path: Union[str, Path]) -> List[str]:
    """Read proxy entries from *path*, one per line.

    Blank lines and lines starting with ``#`` are ignored. An entry without
    a scheme is validated as ``http://<entry>``. Entries that cannot be
    parsed as a URL, or whose scheme is not in ``ALLOWED_PROXY_SCHEME``,
    are skipped.

    :return: the accepted entries, stripped but otherwise as written.
    """
    lines = []
    with open(path, encoding="utf-8") as file:
        for raw_line in file.read().splitlines():
            line = raw_line.strip()
            if not line or line.startswith("#"):
                continue
            try:
                proxy_url = URL(line)
                if not proxy_url.is_absolute():
                    # TODO: logger.warn(...)
                    proxy_url = URL(f"http://{line}")
            except ValueError:
                # TODO: logger.warn(...)
                # Unparsable entry: skip it like other invalid entries
                # instead of failing the whole file.
                continue
            if proxy_url.scheme not in ALLOWED_PROXY_SCHEME:
                # TODO: logger.warn(...)
                continue
            lines.append(line)
    return lines
def make_url(self, path):
    """Resolve the relative *path* against the server root ``self._root``.

    An absolute *path* is rejected by ``assert``.
    """
    relative = URL(path)
    assert not relative.is_absolute()
    return self._root.join(relative)
def test_is_absolute_path_starting_from_double_slash():
    """A protocol-relative URL (``//host``) is reported as absolute."""
    assert URL("//www.python.org").is_absolute()
def test_is_non_absolute_for_empty_url2():
    """An empty URL is not absolute."""
    assert not URL("").is_absolute()
def test_is_absolute_for_absolute_url():
    """A URL with scheme and host is absolute."""
    assert URL("http://example.com").is_absolute()
def test_is_absolute_for_relative_url():
    """A bare path (no scheme, no host) is not absolute."""
    assert not URL("/path/to").is_absolute()
async def handleRequest(self, request, uid):
    """Serve an intercepted Gemini request.

    Builds the Gemini URL from the request's host, path and query, then
    fetches it in the app executor via ``self.geminiRequest``. Depending on
    the response it serves raw content, renders the input form, redirects,
    fails the request, or renders the gemtext capsule as HTML.
    """
    rUrl = request.requestUrl()
    rInitiator = request.initiator()
    rMethod = bytes(request.requestMethod()).decode()

    host = rUrl.host()
    path = rUrl.path()

    if not host:
        return self.urlInvalid(request)

    # Build the URL, using the query params if present
    if rUrl.hasQuery():
        q = rUrl.query(QUrl.EncodeSpaces)
        url = ignition.url(f'{path}?{q}', f'//{host}')
    else:
        url = ignition.url(path, f'//{host}')

    if not rInitiator.isEmpty():
        log.debug(f'{rMethod}: {url} (initiator: {rInitiator.toString()})')
    else:
        log.debug(f'{rMethod}: {url}')

    # Run the request in the app's executor
    response, data = await self.app.loop.run_in_executor(
        self.app.executor,
        self.geminiRequest, url)

    if not response or not data:
        return self.reqFailed(request)

    meta = response.meta

    if isinstance(data, bytes) and meta:
        # Raw file
        return self.serveContent(request.reqUid, request, meta, data)

    if response.is_a(ignition.InputResponse):
        # Gemini input, serve the form
        log.debug(f'{rMethod}: {url}: input requested')

        return await self.serveTemplate(request,
                                        'gemini_input.html',
                                        geminput=data,
                                        gemurl=url,
                                        title=url)
    elif response.is_a(ignition.RedirectResponse):
        # Redirects
        rInfo = data.strip()
        log.debug(f'{rMethod}: {url}: redirect spec is: {rInfo}')

        redirUrl = URL(rInfo)

        if not redirUrl.is_absolute():
            # Relative redirect
            if redirUrl.path.startswith('/'):
                redirUrl = URL.build(scheme=SCHEME_GEMINI,
                                     host=host, path=rInfo)
            else:
                # Resolve against the *directory* of the current URL
                # (RFC 3986 section 5.2). Appending to the full URL nested
                # the target under the current document and kept its query
                # string. URL.join() is avoided on purpose: it may delegate
                # to urllib's urljoin, which does not resolve schemes outside
                # its built-in list (gemini is not one).
                base = URL(str(url))
                baseDir = base.raw_path[:base.raw_path.rfind('/') + 1] or '/'
                redirUrl = URL(
                    str(base.with_path(baseDir, encoded=True)) + rInfo)

        log.debug(f'Gemini ({url}): redirecting to: {redirUrl}')
        return request.redirect(QUrl(str(redirUrl)))
    elif response.is_a(ignition.TempFailureResponse):
        return self.reqFailed(request)
    elif response.is_a(ignition.PermFailureResponse):
        return self.reqFailed(request)
    elif response.is_a(ignition.ClientCertRequiredResponse):
        return self.reqFailed(request)
    elif response.is_a(ignition.ErrorResponse):
        return self.reqFailed(request)

    try:
        if not response.success():
            raise GeminiError(
                f'{response.url}: Invalid response: {response.status}')

        html, title = gemTextToHtml(data)

        if not html:
            raise GeminiError(f'{response.url}: gem2html failed')

        await self.serveTemplate(request,
                                 'gemini_capsule_render.html',
                                 gembody=html,
                                 gemurl=url,
                                 title=title if title else url)
    except Exception as err:
        log.debug(f'{rMethod}: {url}: error rendering capsule: {err}')
        return self.reqFailed(request)
def absurl(s):
    """argparse: Absolute URL

    ``type=`` converter that parses *s* as a URL and rejects it with
    :class:`argparse.ArgumentTypeError` unless it is absolute.
    """
    parsed = URL(s)
    if not parsed.is_absolute():
        raise argparse.ArgumentTypeError('Must be absolute')
    return parsed
async def request(
    self,
    method: str,
    url: URL,
    *,
    auth: str,
    params: Optional[Mapping[str, str]] = None,
    data: Any = None,
    json: Any = None,
    headers: Optional[Dict[str, str]] = None,
    timeout: Optional[aiohttp.ClientTimeout] = None,
) -> AsyncIterator[aiohttp.ClientResponse]:
    """Send an authenticated request and yield the response on success.

    *url* must be absolute (asserted). *auth* is sent as the
    ``Authorization`` header. ``Content-Type: application/json`` is added
    when *json* is given and no content type was supplied. *params*
    replace the URL's query. A trace id (``self._trace_id`` or a fresh one
    from ``gen_trace_id()``) is passed via ``trace_request_ctx``.

    Error statuses (>= 400) raise:
      * :class:`OSError` for a 400 whose JSON body has an ``errno`` field;
      * otherwise ``self._exception_map[status]`` (default
        ``IllegalArgumentError``).
    The message is the JSON ``error`` field when present, else the body text.
    """
    assert url.is_absolute()
    log.debug("Fetch [%s] %s", method, url)
    if headers is not None:
        real_headers: CIMultiDict[str] = CIMultiDict(headers)
    else:
        real_headers = CIMultiDict()
    real_headers["Authorization"] = auth
    if "Content-Type" not in real_headers and json is not None:
        real_headers["Content-Type"] = "application/json"
    trace_request_ctx = SimpleNamespace()
    trace_id = self._trace_id
    if trace_id is None:
        trace_id = gen_trace_id()
    trace_request_ctx.trace_id = trace_id
    if params:
        url = url.with_query(params)
    async with self._session.request(
        method,
        url,
        headers=real_headers,
        json=json,
        data=data,
        timeout=timeout,
        trace_request_ctx=trace_request_ctx,
    ) as resp:
        if 400 <= resp.status:
            err_text = await resp.text()
            payload: Dict[str, Any] = {}
            if resp.content_type.lower() == "application/json":
                try:
                    decoded = jsonmodule.loads(err_text)
                except ValueError:
                    # One example would be a HEAD request for application/json
                    decoded = None
                # Only a JSON object can carry "error"/"errno" fields. For a
                # JSON string or list, `"error" in payload` is a substring or
                # element test and payload["error"] then raises TypeError.
                if isinstance(decoded, dict):
                    payload = decoded
                if "error" in payload:
                    err_text = payload["error"]
            if resp.status == 400 and "errno" in payload:
                os_errno: Any = payload["errno"]
                os_errno = errno.__dict__.get(os_errno, os_errno)
                raise OSError(os_errno, err_text)
            err_cls = self._exception_map.get(resp.status, IllegalArgumentError)
            raise err_cls(err_text)
        else:
            try:
                yield resp
            except GeneratorExit:
                # There is a bug in CPython and/or aiohttp,
                # if GeneratorExit is reraised @asynccontextmanager
                # reports this as an error
                # Need to investigate and fix.
                raise asyncio.CancelledError
class Connection(Base): FRAME_BUFFER = 10 # Interval between sending heartbeats based on the heartbeat(timeout) HEARTBEAT_INTERVAL_MULTIPLIER = 0.5 # Allow two missed heartbeats (based on heartbeat(timeout) HEARTBEAT_GRACE_MULTIPLIER = 3 _HEARTBEAT = pamqp.frame.marshal(Heartbeat(), 0) READER_CLOSE_TIMEOUT = 2 @staticmethod def _parse_ca_data(data) -> typing.Optional[bytes]: return b64decode(data) if data else data def __init__(self, url: URLorStr, *, parent=None, loop: asyncio.AbstractEventLoop = None): super().__init__(loop=loop or asyncio.get_event_loop(), parent=parent) self.url = URL(url) if self.url.is_absolute() and not self.url.port: self.url = self.url.with_port(DEFAULT_PORTS[self.url.scheme]) if self.url.path == "/" or not self.url.path: self.vhost = "/" else: self.vhost = self.url.path[1:] self._reader_task = None # type: asyncio.Task self.reader = None # type: asyncio.StreamReader self.writer = None # type: asyncio.StreamWriter self.ssl_certs = SSLCerts( cafile=self.url.query.get("cafile"), capath=self.url.query.get("capath"), cadata=self._parse_ca_data(self.url.query.get("cadata")), key=self.url.query.get("keyfile"), cert=self.url.query.get("certfile"), verify=self.url.query.get("no_verify_ssl", "0") == "0", ) self.started = False self.__lock = asyncio.Lock() self.__drain_lock = asyncio.Lock() self.channels = {} # type: typing.Dict[int, typing.Optional[Channel]] self.server_properties = None # type: spec.Connection.OpenOk self.connection_tune = None # type: spec.Connection.TuneOk self.last_channel = 1 self.heartbeat_monitoring = parse_bool( self.url.query.get("heartbeat_monitoring", "1"), ) self.heartbeat_timeout = parse_int( self.url.query.get("heartbeat", "0"), ) self.heartbeat_last_received = 0 self.last_channel_lock = asyncio.Lock() self.connected = asyncio.Event() self.connection_name = self.url.query.get("name") @property def lock(self): if self.is_closed: raise RuntimeError("%r closed" % self) return self.__lock async def drain(self): async with 
self.__drain_lock: if not self.writer: raise RuntimeError("Writer is %r" % self.writer) return await self.writer.drain() @property def is_opened(self): return self.writer is not None and not self.is_closed def __str__(self): return str(censor_url(self.url)) def _get_ssl_context(self): context = ssl.create_default_context( ssl.Purpose.SERVER_AUTH, capath=self.ssl_certs.capath, cafile=self.ssl_certs.cafile, cadata=self.ssl_certs.cadata, ) if self.ssl_certs.cert or self.ssl_certs.key: context.load_cert_chain(self.ssl_certs.cert, self.ssl_certs.key) if not self.ssl_certs.verify: context.check_hostname = False context.verify_mode = ssl.CERT_NONE return context def _client_properties(self, **kwargs): properties = { "platform": PLATFORM, "version": __version__, "product": PRODUCT, "capabilities": { "authentication_failure_close": True, "basic.nack": True, "connection.blocked": False, "consumer_cancel_notify": True, "publisher_confirms": True, }, "information": "See https://github.com/mosquito/aiormq/", } properties.update(parse_connection_name(self.connection_name)) properties.update(kwargs.get("client_properties", {})) return properties @staticmethod def _credentials_class( start_frame: spec.Connection.Start) -> typing.Type[AuthMechanism]: for mechanism in start_frame.mechanisms.split(): with suppress(KeyError): return AuthMechanism[mechanism] raise exc.AuthenticationError( start_frame.mechanisms, [m.name for m in AuthMechanism], ) async def __rpc(self, request: Frame, wait_response=True): self.writer.write(pamqp.frame.marshal(request, 0)) if not wait_response: return _, _, frame = await self.__receive_frame() if request.synchronous and frame.name not in request.valid_responses: raise spec.AMQPInternalError(frame, dict(frame)) elif isinstance(frame, spec.Connection.Close): if frame.reply_code == 403: err = exc.ProbableAuthenticationError(frame.reply_text) else: err = exc.ConnectionClosed(frame.reply_code, frame.reply_text) await self.close(err) raise err return frame 
@task async def connect(self, client_properties: dict = None): if self.writer is not None: raise RuntimeError("Already connected") ssl_context = None if self.url.scheme == "amqps": ssl_context = await self.loop.run_in_executor( None, self._get_ssl_context, ) try: self.reader, self.writer = await asyncio.open_connection( self.url.host, self.url.port, ssl=ssl_context, ) except OSError as e: raise ConnectionError(*e.args) from e try: protocol_header = ProtocolHeader() self.writer.write(protocol_header.marshal()) res = await self.__receive_frame() _, _, frame = res # type: spec.Connection.Start self.heartbeat_last_received = self.loop.time() except EOFError as e: raise exc.IncompatibleProtocolError(*e.args) from e credentials = self._credentials_class(frame) self.server_properties = frame.server_properties # noinspection PyTypeChecker self.connection_tune = await self.__rpc( spec.Connection.StartOk( client_properties=self._client_properties( **(client_properties or {}), ), mechanism=credentials.name, response=credentials.value(self).marshal(), ), ) # type: spec.Connection.Tune if self.heartbeat_timeout > 0: self.connection_tune.heartbeat = self.heartbeat_timeout await self.__rpc( spec.Connection.TuneOk( channel_max=self.connection_tune.channel_max, frame_max=self.connection_tune.frame_max, heartbeat=self.connection_tune.heartbeat, ), wait_response=False, ) await self.__rpc(spec.Connection.Open(virtual_host=self.vhost)) # noinspection PyAsyncCall self._reader_task = self.create_task(self.__reader()) # noinspection PyAsyncCall heartbeat_task = self.create_task(self.__heartbeat_task()) heartbeat_task.add_done_callback(self._on_heartbeat_done) self.loop.call_soon(self.connected.set) return True def _on_heartbeat_done(self, future): if not future.cancelled() and future.exception(): self.create_task( self.close(ConnectionError("heartbeat task was failed.")), ) async def __heartbeat_task(self): if not self.connection_tune.heartbeat: return heartbeat_interval = 
(self.connection_tune.heartbeat * self.HEARTBEAT_INTERVAL_MULTIPLIER) heartbeat_grace_timeout = (self.connection_tune.heartbeat * self.HEARTBEAT_GRACE_MULTIPLIER) while self.writer: # Send heartbeat to server unconditionally self.writer.write(self._HEARTBEAT) await asyncio.sleep(heartbeat_interval) if not self.heartbeat_monitoring: continue # Check if the server sent us something # within the heartbeat grace period last_heartbeat = self.loop.time() - self.heartbeat_last_received if last_heartbeat <= heartbeat_grace_timeout: continue await self.close( ConnectionError( "Server connection probably hang, last heartbeat " "received %.3f seconds ago" % last_heartbeat, ), ) return async def __receive_frame(self) -> typing.Tuple[int, int, Frame]: async with self.lock: frame_header = await self.reader.readexactly(1) if frame_header == b"\0x00": raise spec.AMQPFrameError(await self.reader.read()) if self.reader is None: raise ConnectionError frame_header += await self.reader.readexactly(6) if not self.started and frame_header.startswith(b"AMQP"): raise spec.AMQPSyntaxError else: self.started = True frame_type, _, frame_length = pamqp.frame.frame_parts(frame_header) frame_payload = await self.reader.readexactly(frame_length + 1) return pamqp.frame.unmarshal(frame_header + frame_payload) @staticmethod def __exception_by_code(frame: spec.Connection.Close): if frame.reply_code == 501: return exc.ConnectionFrameError(frame.reply_text) elif frame.reply_code == 502: return exc.ConnectionSyntaxError(frame.reply_text) elif frame.reply_code == 503: return exc.ConnectionCommandInvalid(frame.reply_text) elif frame.reply_code == 504: return exc.ConnectionChannelError(frame.reply_text) elif frame.reply_code == 505: return exc.ConnectionUnexpectedFrame(frame.reply_text) elif frame.reply_code == 506: return exc.ConnectionResourceError(frame.reply_text) elif frame.reply_code == 530: return exc.ConnectionNotAllowed(frame.reply_text) elif frame.reply_code == 540: return 
exc.ConnectionNotImplemented(frame.reply_text) elif frame.reply_code == 541: return exc.ConnectionInternalError(frame.reply_text) else: return exc.ConnectionClosed(frame.reply_code, frame.reply_text) @task async def __reader(self): try: while not self.reader.at_eof(): weight, channel, frame = await self.__receive_frame() self.heartbeat_last_received = self.loop.time() if channel == 0: if isinstance(frame, spec.Connection.CloseOk): return if isinstance(frame, spec.Connection.Close): return await self.close( self.__exception_by_code(frame), ) elif isinstance(frame, Heartbeat): continue elif isinstance(frame, spec.Channel.CloseOk): self.channels.pop(channel, None) log.error("Unexpected frame %r", frame) continue ch = self.channels.get(channel) if ch is None: log.error( "Got frame for closed channel %d: %r", channel, frame, ) continue if isinstance(frame, CHANNEL_CLOSE_RESPONSES): self.channels[channel] = None await ch.frames.put((weight, frame)) except asyncio.CancelledError as e: log.debug("Reader task cancelled:", exc_info=e) except asyncio.IncompleteReadError as e: log.debug("Can not read bytes from server:", exc_info=e) await self.close(ConnectionError(*e.args)) except Exception as e: log.debug("Reader task exited because:", exc_info=e) await self.close(e) @staticmethod async def __close_writer(writer: asyncio.StreamWriter): if writer is None: return writer.close() if hasattr(writer, "wait_closed"): await writer.wait_closed() async def _on_close(self, ex=exc.ConnectionClosed(0, "normal closed")): frame = (spec.Connection.CloseOk() if isinstance( ex, exc.ConnectionClosed) else spec.Connection.Close()) await asyncio.gather( self.__rpc(frame, wait_response=False), return_exceptions=True, ) writer = self.writer self.reader = None self.writer = None reader = self._reader_task self._reader_task = None await asyncio.gather( self.__close_writer(writer), return_exceptions=True, ) if not isinstance(reader, asyncio.Task) or reader.done(): return try: await 
asyncio.wait_for(asyncio.gather(reader, return_exceptions=True), timeout=self.READER_CLOSE_TIMEOUT) except asyncio.TimeoutError: reader.cancel() await asyncio.gather(reader, return_exceptions=True) @property def server_capabilities(self) -> ArgumentsType: return self.server_properties["capabilities"] @property def basic_nack(self) -> bool: return self.server_capabilities.get("basic.nack") @property def consumer_cancel_notify(self) -> bool: return self.server_capabilities.get("consumer_cancel_notify") @property def exchange_exchange_bindings(self) -> bool: return self.server_capabilities.get("exchange_exchange_bindings") @property def publisher_confirms(self): return self.server_capabilities.get("publisher_confirms") async def channel(self, channel_number: int = None, publisher_confirms=True, frame_buffer=FRAME_BUFFER, **kwargs) -> Channel: await self.connected.wait() if self.is_closed: raise RuntimeError("%r closed" % self) if not self.publisher_confirms and publisher_confirms: raise ValueError("Server doesn't support publisher_confirms") if channel_number is None: async with self.last_channel_lock: if self.channels: self.last_channel = max(self.channels.keys()) while self.last_channel in self.channels.keys(): self.last_channel += 1 if self.last_channel > 65535: log.warning("Resetting channel number for %r", self) self.last_channel = 1 # switching context for prevent blocking event-loop await asyncio.sleep(0) channel_number = self.last_channel elif channel_number in self.channels: raise ValueError("Channel %d already used" % channel_number) if channel_number < 0 or channel_number > 65535: raise ValueError("Channel number too large") channel = Channel( self, channel_number, frame_buffer=frame_buffer, publisher_confirms=publisher_confirms, **kwargs, ) self.channels[channel_number] = channel try: await channel.open() except Exception: self.channels[channel_number] = None raise return channel async def __aenter__(self): await self.connect()