def make_retry_decorator(
    retries: int, delay: float
) -> typing.Callable[[typing.Callable], typing.Callable]:
    """Build a tenacity decorator for calls that return HTTP-like responses.

    A call is retried when its result has ``status >= 500`` or when it raises
    ``aiohttp.ClientError``.

    Args:
        retries: number of retries after the first attempt (the call is tried
            at most ``retries + 1`` times).
        delay: fixed wait, in seconds, between two attempts.

    Returns:
        The configured ``retry`` decorator.
    """
    got_server_error = retry_if_result(lambda res: res.status >= 500)
    raised_client_error = retry_if_exception_type(
        exception_types=aiohttp.ClientError)
    return retry(
        wait=wait_fixed(delay),
        retry=(got_server_error | raised_client_error),
        stop=stop_after_attempt(retries + 1),
    )
async def test_interactive_services_removed_after_logout(
    client: TestClient,
    logged_user: Dict[str, Any],
    empty_user_project: Dict[str, Any],
    mocked_director_v2_api: Dict[str, mock.MagicMock],
    create_dynamic_service_mock,
    client_session_id_factory: Callable[[], str],
    socketio_client_factory: Callable,
    storage_subsystem_mock: MockedStorageSubsystem,  # when guest user logs out garbage is collected
    director_v2_service_mock: aioresponses,
    expected_save_state: bool,
):
    """Check that logging out leads to the user's dynamic service being stopped.

    The user, the empty study and the dynamic service come from fixtures.
    After logout the garbage collector is run explicitly and the
    ``stop_service`` mock is polled until it was awaited with the expected
    arguments.
    """
    # the user is logged in and owns an empty study (fixtures); start a
    # dynamic service in that study
    dynamic_service = await create_dynamic_service_mock(
        logged_user["id"], empty_user_project["uuid"])

    # connect a websocket for this client session (the handle is not used)
    session_id = client_session_id_factory()
    await socketio_client_factory(session_id)

    # open the project from that session
    await open_project(client, empty_user_project["uuid"], session_id)

    # log out
    logout_url = client.app.router["auth_logout"].url_for()
    logout_response = await client.post(
        f"{logout_url}", json={"client_session_id": session_id})
    assert logout_response.url_obj.path == logout_url.path
    await assert_status(logout_response, web.HTTPOk)

    # let the deletion delay elapse, then trigger the collection by hand
    await asyncio.sleep(SERVICE_DELETION_DELAY + 1)
    await garbage_collector_core.collect_garbage(client.app)

    # the service is removed in a fire-and-forget way, so poll with some leeway
    polling = AsyncRetrying(
        reraise=True, stop=stop_after_attempt(10), wait=wait_fixed(1))
    async for attempt in polling:
        with attempt:
            logger.warning(
                "Waiting for stop to have been called service_uuid=%s, save_state=%s",
                dynamic_service["service_uuid"],
                expected_save_state,
            )
            stop_service_mock = mocked_director_v2_api[
                "director_v2_core.stop_service"]
            stop_service_mock.assert_awaited_with(
                app=client.server.app,
                service_uuid=dynamic_service["service_uuid"],
                save_state=expected_save_state,
            )
async def create(
    cls,
    app: FastAPI,
    settings: DaskSchedulerSettings,
    endpoint: AnyUrl,
    authentication: ClusterAuthentication,
) -> "DaskClient":
    """Connect to the dask scheduler/gateway and return a ready client.

    The connection is attempted up to 3 times, 0.3 seconds apart; the last
    error is re-raised if every attempt fails.

    Args:
        app: application handed to the new instance.
        settings: scheduler settings handed to the new instance.
        endpoint: address of the scheduler or gateway.
        authentication: credentials used to create the internal client.

    Returns:
        The connected instance of ``cls``.
    """
    target = f"dask-scheduler/gateway at {endpoint}"
    logger.info("Initiating connection to %s with auth: %s", target,
                authentication)

    connection_policy = AsyncRetrying(
        reraise=True,
        before_sleep=before_sleep_log(logger, logging.WARNING),
        wait=wait_fixed(0.3),
        stop=stop_after_attempt(3),
    )
    async for attempt in connection_policy:
        with attempt:
            logger.debug(
                "Connecting to %s, attempt %s...",
                endpoint,
                attempt.retry_state.attempt_number,
            )
            subsystem = await _create_internal_client_based_on_auth(
                endpoint, authentication)
            check_scheduler_status(subsystem.client)

            cancellation_pub = distributed.Pub(
                TaskCancelEvent.topic_name(), client=subsystem.client)
            instance = cls(
                app=app,
                dask_subsystem=subsystem,
                settings=settings,
                cancellation_dask_pub=cancellation_pub,
            )

            retry_statistics = attempt.retry_state.retry_object.statistics
            logger.info("Connection to %s succeeded [%s]", target,
                        json.dumps(retry_statistics))
            logger.info(
                "Scheduler info:\n%s",
                json.dumps(subsystem.client.scheduler_info(), indent=2),
            )
            return instance

    # the loop always returns or raises; this only satisfies pylance
    raise ValueError("Could not create client")
async def managed_docker_compose(postgres_volume_name: str,
                                 postgres_username: str,
                                 postgres_password: str):
    """Start the consistency docker-compose stack, yield, then tear it down.

    The stack is started with ``POSTGRES_DATA_VOLUME`` set to
    ``postgres_volume_name``. Control is yielded once postgres answers a
    ``SELECT 1`` (up to 10 attempts, 1-3 seconds apart). ``docker-compose
    down`` always runs on exit, including when start-up failed.
    """
    typer.echo("starting up database in localhost")
    compose_file = Path.cwd() / "consistency" / "docker-compose.yml"
    compose_dir = compose_file.parent
    try:
        compose_env = {
            **os.environ,
            "POSTGRES_DATA_VOLUME": postgres_volume_name,
        }
        subprocess.run(
            ["docker-compose", "--file", compose_file, "up", "--detach"],
            shell=False,
            check=True,
            cwd=compose_dir,
            env=compose_env,
        )
        typer.echo(
            f"database started: adminer available on http://127.0.0.1:18080/?pgsql=postgres&username={postgres_username}&db=simcoredb&ns=public"
        )

        dsn = f"dbname=simcoredb user={postgres_username} password={postgres_password} host=127.0.0.1"

        @retry(
            wait=wait_random(1, 3),
            stop=stop_after_attempt(10),
            after=after_log(log, logging.WARN),
        )
        async def postgres_responsive():
            # any failure to connect or to query makes tenacity try again
            async with aiopg.create_pool(dsn) as pool:
                async with pool.acquire() as connection:
                    async with connection.cursor() as cursor:
                        await cursor.execute("SELECT 1")

        await postgres_responsive()
        yield
    finally:
        subprocess.run(
            ["docker-compose", "--file", compose_file, "down"],
            shell=False,
            check=True,
            cwd=compose_dir,
        )
async def assert_service_is_available(  # pylint: disable=redefined-outer-name
        exposed_port: PositiveInt, is_legacy: bool,
        service_uuid: str) -> None:
    """Poll a service over HTTP until it answers with 200 OK.

    Legacy services are queried under ``/x/<service_uuid>``, the others at the
    root of the exposed port. The request is attempted up to 60 times, one
    second apart; the last failure is re-raised.
    """
    base_address = f"http://{get_ip()}:{exposed_port}"
    if is_legacy:
        service_address = f"{base_address}/x/{service_uuid}"
    else:
        service_address = base_address
    print(f"checking service @ {service_address}")

    polling = AsyncRetrying(
        wait=wait_fixed(1), stop=stop_after_attempt(60), reraise=True)
    async for attempt in polling:
        with attempt:
            async with httpx.AsyncClient() as http_client:
                reply = await http_client.get(service_address)
                print(
                    f"{SEPARATOR}\nAttempt={attempt.retry_state.attempt_number}"
                )
                print(
                    f"Body:\n{reply.text}\nHeaders={reply.headers}\n{SEPARATOR}"
                )
                assert reply.status_code == httpx.codes.OK, reply.text
def _has_retryable_error_code(exc, exceptions):
    """Tell whether ``exc`` carries one of the retryable API error codes.

    The code is read from ``exc.response['Error']['Code']``. An exception
    without that structure (no ``response``, ``response`` set to ``None`` or
    to a non-mapping object, ``Error`` not a mapping) is simply not retryable;
    the lookup must never raise, otherwise the lookup failure would replace
    the original error.
    """
    if not exc:
        return False
    try:
        code = getattr(exc, 'response', {}).get('Error', {}).get('Code', None)
    except (AttributeError, TypeError):
        return False
    return code in exceptions


def retry_api_call(func,
                   exceptions=('ThrottlingException',
                               'TooManyRequestsException'),
                   attempt=5,
                   multiplier=1,
                   max_delay=1800,
                   exp_base=2,
                   logger=None,
                   *args,
                   **kwargs):
    """Call ``func`` and retry it while the API reports a throttling error.

    Args:
        func: callable to invoke.
        exceptions: error codes (``response['Error']['Code']``) that trigger
            a retry; any other exception propagates immediately.
        attempt: maximum number of attempts.
        multiplier: multiplier of the exponential wait.
        max_delay: upper bound of the wait between attempts.
        exp_base: base of the exponential wait.
        logger: optional logger; when given, each finished attempt is logged
            at the logger's own level.
        *args: positional arguments forwarded to ``func``.
        **kwargs: keyword arguments forwarded to ``func``.

    Returns:
        Whatever ``func`` returns.

    Raises:
        The last exception raised by ``func`` (``reraise=True``).
    """
    retry = tenacity.Retrying(
        retry=retry_if_exception(
            lambda e: _has_retryable_error_code(e, exceptions)),
        stop=stop_after_attempt(attempt),
        wait=wait_exponential(multiplier=multiplier,
                              max=max_delay,
                              exp_base=exp_base),
        after=after_log(logger, logger.level) if logger else None,
        reraise=True)
    return retry(func, *args, **kwargs)
async def ensure_volume_cleanup(docker_client: aiodocker.Docker,
                                node_uuid: str) -> None:
    """Wait until the ``dy-sidecar_<node_uuid>`` volumes are gone.

    For every matching volume found at call time, the volume list is polled
    (up to 15 times, 5 seconds apart) until the volume no longer shows up.
    """

    async def _get_volume_names() -> Set[str]:
        listing = await docker_client.volumes.list()
        return {entry["Name"] for entry in listing["Volumes"]}

    sidecar_prefix = f"dy-sidecar_{node_uuid}"
    for volume_name in await _get_volume_names():
        if not volume_name.startswith(sidecar_prefix):
            continue

        # docker reports the volume as in use for a while after the
        # service is done with it, so removal is not immediate
        removal_polling = AsyncRetrying(
            reraise=False,
            stop=stop_after_attempt(15),
            wait=wait_fixed(5),
        )
        async for attempt in removal_polling:
            with attempt:
                # still listed: raise so that tenacity schedules another check
                if volume_name in await _get_volume_names():
                    raise _VolumeNotExpectedError(volume_name)
def __init__(self,
             hl='en-US',
             tz=360,
             geo='',
             timeout=DEFAULT_TIMEOUT_CONFIG,
             proxies=[],
             retries=0,
             backoff_factor=0):
    """
    Initialize default values for params
    """
    # message used to recognise the google rate limit
    self.google_rl = 'You have reached your quota limit. Please try again later.'

    # user defined options used globally
    self.hl = hl
    self.tz = tz
    self.geo = geo
    self.timeout = timeout
    self.kw_list = []

    # proxy handling; the caller's list is copied, never mutated
    self.proxies = proxies.copy()
    self.blacklisted_proxies = []
    self._rate_limited_proxies = []
    self.proxy_index = 0
    self.cookies = None

    # widget payloads, empty until filled in
    self.token_payload = {}
    self.interest_over_time_widget = {}
    self.interest_by_region_widget = {}
    self.related_topics_widget_list = []
    self.related_queries_widget_list = []

    # retry policy
    self.retries = retries
    self.backoff_factor = backoff_factor
    self._retry_config = {
        'wait': wait_exponential(multiplier=self.backoff_factor),
        'stop': stop_after_attempt(self.retries),
        'reraise': True,
    }
def wrapper(f: Any) -> Any:
    """Wrap coroutine function ``f`` so that its initialization is retried.

    ``f`` is attempted up to 5 times with an exponential wait. Each failure is
    logged together with the attempt count before being raised again for
    tenacity to handle.
    """

    @wraps(f)
    @retry(stop=stop_after_attempt(5), wait=wait_exponential(2))
    async def wrapped(*args: Any, **kwargs: Any) -> Any:
        attempt_number = wrapped.retry.statistics["attempt_number"]
        try:
            outcome = await f(*args, **kwargs)
            logger.info("{} initialized", name)
            return outcome
        except Exception as exc:
            allowed_attempts = wrapped.retry.stop.max_attempt_number
            headline = "{}: initialization failed ({}/{})".format(
                name, attempt_number, allowed_attempts)
            # only announce a new attempt when one is still to come
            ending = (", trying again after {} second.".format(
                2**attempt_number)
                      if attempt_number < allowed_attempts else ".")
            logger.exception(exc)
            logger.warning(headline + ending)
            raise exc

    return wrapped
def retry(
    retries: int, delay: float
) -> typing.Callable[[typing.Callable], typing.Callable]:
    """Return a decorator retrying a call at a fixed pace.

    Args:
        retries: number of retries after the first attempt (at most
            ``retries + 1`` attempts overall).
        delay: wait between two attempts, passed to ``wait_fixed``.

    Returns:
        The configured decorator; the last exception is re-raised once the
        attempts are used up.
    """
    policy = {
        "wait": wait_fixed(delay),
        "stop": stop_after_attempt(retries + 1),
        "reraise": True,
    }
    return _retry(**policy)
def _retry_blob_operation() -> Any:
    """Build the retry policy shared by the blob operations below.

    Up to 10 attempts, a random 1-3 second wait in between, retrying on any
    ``Exception`` and re-raising the last one.
    """
    return retry(
        stop=stop_after_attempt(10),
        wait=wait_random(min=1, max=3),
        retry=retry_if_exception_type(),
        before_sleep=before_sleep,
        reraise=True,
    )


class ContainerWrapper:
    """Thin wrapper around a blob container addressed by its URL."""

    client: ContainerClient

    def __init__(self, container_url: str) -> None:
        self.container_url = container_url
        self.client = ContainerClient.from_container_url(container_url)

    @_retry_blob_operation()
    def upload_file(self, file_path: str, blob_name: str) -> None:
        """Upload the local file to ``blob_name``, overwriting any blob there."""
        with open(file_path, "rb") as source:
            self.client.upload_blob(name=blob_name,
                                    data=source,
                                    overwrite=True,
                                    max_concurrency=10)

    def upload_file_data(self, data: str, blob_name: str) -> None:
        """Upload the string ``data`` to ``blob_name`` (not retried)."""
        self.client.upload_blob(name=blob_name,
                                data=data,
                                overwrite=True,
                                max_concurrency=10)

    def upload_dir(self, dir_path: str) -> None:
        """Sync the local directory into the container with azcopy."""
        # security note: the src for azcopy comes from the server which is
        # trusted in this context, while the destination is provided by the
        # user
        azcopy_sync(dir_path, self.container_url)

    def download_dir(self, dir_path: str) -> None:
        """Sync the container into the local directory with azcopy."""
        # security note: the src for azcopy comes from the server which is
        # trusted in this context, while the destination is provided by the
        # user
        azcopy_sync(self.container_url, dir_path)

    @_retry_blob_operation()
    def delete_blob(self, blob_name: str) -> None:
        """Delete ``blob_name`` from the container."""
        self.client.delete_blob(blob_name)

    @_retry_blob_operation()
    def download_blob(self, blob_name: str) -> bytes:
        """Return the content of ``blob_name`` as bytes."""
        downloader = self.client.download_blob(blob_name)
        return cast(bytes, downloader.content_as_bytes())

    @_retry_blob_operation()
    def list_blobs(self,
                   *,
                   name_starts_with: Optional[str] = None) -> List[str]:
        """Return the blob names, optionally limited to a name prefix."""
        blobs = self.client.list_blobs(name_starts_with=name_starts_with)
        return cast(List[str], [blob.name for blob in blobs])
class ContainerWrapper:
    """Wrapper around a blob container addressed by its (SAS) URL.

    Uploads go through azcopy first and fall back to the Python SDK; the
    other blob operations use the SDK client directly.
    """

    client: ContainerClient

    def __init__(self, container_url: str) -> None:
        self.client = ContainerClient.from_container_url(container_url)
        self.container_url = container_url

    @retry(
        stop=stop_after_attempt(10),
        wait=wait_random(min=1, max=3),
        retry=retry_if_exception_type(),
        before_sleep=before_sleep,
        reraise=True,
    )
    def upload_file(self, file_path: str, blob_name: str) -> None:
        """Upload a local file to ``blob_name``, overwriting any blob there.

        azcopy is tried first. On any failure (including a container URL
        without a ``?`` query part) the SDK is used instead; an SDK failure
        propagates to the retry policy, which starts over with azcopy.
        """
        try:
            # Split the container URL to insert the blob_name
            url_parts = self.container_url.split("?", 1)

            # Default to azcopy if it is installed
            azcopy_copy(file_path,
                        url_parts[0] + "/" + blob_name + "?" + url_parts[1])
        except Exception as exc:
            # A subprocess exception would typically only contain the exit status.
            LOGGER.warning(
                "Upload using azcopy failed. Check the azcopy logs for more information."
            )
            LOGGER.warning(exc)
            # Indicate the switch in the approach for clarity in debugging
            LOGGER.warning("Now attempting to upload using the Python SDK...")

            # This does not have a try/except since it should be caught by the retry system.
            # The retry system will always attempt azcopy first and this approach second
            with open(file_path, "rb") as handle:
                # Using the Azure SDK default max_concurrency
                self.client.upload_blob(name=blob_name,
                                        data=handle,
                                        overwrite=True)
        return None

    def upload_file_data(self, data: str, blob_name: str) -> None:
        """Upload the string ``data`` to ``blob_name`` through a temp file."""
        with tempfile.TemporaryDirectory() as tmpdir:
            # The local name is irrelevant to the upload (the destination is
            # given by blob_name), so it must not be derived from blob_name:
            # a name containing "/" would point into a missing sub-directory
            # and an absolute one would escape tmpdir altogether.
            filename = os.path.join(tmpdir, "blob_data")
            # utf-8 and no newline translation, so that the uploaded bytes
            # do not depend on the platform
            with open(filename, "w", encoding="utf-8", newline="") as handle:
                handle.write(data)
            self.upload_file(filename, blob_name)

    def upload_dir(self, dir_path: str) -> None:
        """Sync the local directory into the container with azcopy."""
        # security note: the src for azcopy comes from the server which is
        # trusted in this context, while the destination is provided by the
        # user
        azcopy_sync(dir_path, self.container_url)

    def download_dir(self, dir_path: str) -> None:
        """Sync the container into the local directory with azcopy."""
        # security note: the src for azcopy comes from the server which is
        # trusted in this context, while the destination is provided by the
        # user
        azcopy_sync(self.container_url, dir_path)

    @retry(
        stop=stop_after_attempt(10),
        wait=wait_random(min=1, max=3),
        retry=retry_if_exception_type(),
        before_sleep=before_sleep,
        reraise=True,
    )
    def delete_blob(self, blob_name: str) -> None:
        """Delete ``blob_name`` from the container."""
        self.client.delete_blob(blob_name)
        return None

    @retry(
        stop=stop_after_attempt(10),
        wait=wait_random(min=1, max=3),
        retry=retry_if_exception_type(),
        before_sleep=before_sleep,
        reraise=True,
    )
    def download_blob(self, blob_name: str) -> bytes:
        """Return the content of ``blob_name`` as bytes."""
        return cast(bytes,
                    self.client.download_blob(blob_name).content_as_bytes())

    @retry(
        stop=stop_after_attempt(10),
        wait=wait_random(min=1, max=3),
        retry=retry_if_exception_type(),
        before_sleep=before_sleep,
        reraise=True,
    )
    def list_blobs(self,
                   *,
                   name_starts_with: Optional[str] = None) -> List[str]:
        """Return the blob names, optionally limited to a name prefix."""
        result = [
            x.name
            for x in self.client.list_blobs(name_starts_with=name_starts_with)
        ]
        return cast(List[str], result)
def wait_for_services() -> int:
    """Wait until every expected swarm service runs all of its replicas.

    First waits for exactly the expected services (core + ops) to be listed,
    then waits, service by service, for the number of running tasks to match
    the expected replica count.

    Returns:
        ``os.EX_OK`` when everything is up, ``os.EX_SOFTWARE`` when one of
        the waits runs out of attempts.
    """
    expected_services = core_services() + ops_services()
    started_services = []
    client = docker.from_env()

    listing_policy = Retrying(
        stop=stop_after_attempt(MAX_RETRY_COUNT),
        wait=wait_fixed(WAIT_BEFORE_RETRY),
        before_sleep=before_sleep_log(logger, logging.WARNING),
    )
    try:
        for attempt in listing_policy:
            with attempt:
                # a service counts when the last "_"-separated part of its
                # name is one of the expected names
                started_services = sorted(
                    (candidate for candidate in client.services.list()
                     if candidate.name.split("_")[-1] in expected_services),
                    key=by_service_creation,
                )

                assert len(started_services), "no services started!"
                assert len(expected_services) == len(started_services), (
                    "Some services are missing or unexpected:\n"
                    f"expected: {len(expected_services)} {expected_services}\n"
                    f"got: {len(started_services)} {[s.name for s in started_services]}"
                )
    except RetryError:
        print(
            f"found these services: {len(started_services)} {[s.name for s in started_services]}\nexpected services: {len(expected_services)} {expected_services}"
        )
        return os.EX_SOFTWARE

    for service in started_services:
        if "Replicated" in service.attrs["Spec"]["Mode"]:
            expected_replicas = service.attrs["Spec"]["Mode"]["Replicated"][
                "Replicas"]
        else:
            # global mode: one replica per node
            expected_replicas = len(client.nodes.list())
        print(f"Service: {service.name} expects {expected_replicas} replicas",
              "-" * 10)

        replicas_policy = Retrying(
            stop=stop_after_attempt(MAX_RETRY_COUNT),
            wait=wait_fixed(WAIT_BEFORE_RETRY),
        )
        try:
            for attempt in replicas_policy:
                with attempt:
                    service_tasks: List[Dict] = service.tasks()  # freeze
                    print(get_tasks_summary(service_tasks))

                    #
                    # NOTE: a service could set 'ready' as desired-state instead of 'running' if
                    # it constantly breaks and the swarm decides to "stop trying".
                    #
                    running_states = [
                        task["Status"]["State"] == RUNNING_STATE
                        for task in service_tasks
                    ]
                    valid_replicas = sum(running_states)
                    assert valid_replicas == expected_replicas
        except RetryError:
            print(
                f"ERROR: Service {service.name} failed to start {expected_replicas} replica/s"
            )
            print(json.dumps(service.attrs, indent=1))
            return os.EX_SOFTWARE

    return os.EX_OK
class SnowflakeAdapter(BaseSourceAdapter):
    """The Snowflake Data Warehouse source adapter.

    Args:
        preserve_case: By default the adapter folds case-insensitive strings to lowercase.
                       If preserve_case is True,SnowShu will __not__ alter cases (dangerous!).
    """

    name = 'snowflake'
    SUPPORTS_CROSS_DATABASE = True
    SUPPORTED_FUNCTIONS = {'ANY_VALUE', 'RLIKE', 'UUID_STRING'}
    SUPPORTED_SAMPLE_METHODS = (BernoulliSampleMethod, )
    REQUIRED_CREDENTIALS = (
        USER,
        PASSWORD,
        ACCOUNT,
        DATABASE,
    )
    ALLOWED_CREDENTIALS = (
        SCHEMA,
        WAREHOUSE,
        ROLE,
    )
    # snowflake in-db is UPPER, but connector is actually lower :(
    DEFAULT_CASE = 'lower'

    DATA_TYPE_MAPPINGS = {
        "array": dtypes.JSON,
        "bigint": dtypes.BIGINT,
        "binary": dtypes.BINARY,
        "boolean": dtypes.BOOLEAN,
        "char": dtypes.CHAR,
        "character": dtypes.CHAR,
        "date": dtypes.DATE,
        "datetime": dtypes.DATETIME,
        "decimal": dtypes.DECIMAL,
        "double": dtypes.FLOAT,
        "double precision": dtypes.FLOAT,
        "float": dtypes.FLOAT,
        "float4": dtypes.FLOAT,
        "float8": dtypes.FLOAT,
        "int": dtypes.BIGINT,
        "integer": dtypes.BIGINT,
        "number": dtypes.BIGINT,
        "numeric": dtypes.NUMERIC,
        "object": dtypes.JSON,
        "real": dtypes.FLOAT,
        "smallint": dtypes.BIGINT,
        "string": dtypes.VARCHAR,
        "text": dtypes.VARCHAR,
        "time": dtypes.TIME,
        "timestamp": dtypes.TIMESTAMP_NTZ,
        "timestamp_ntz": dtypes.TIMESTAMP_NTZ,
        "timestamp_ltz": dtypes.TIMESTAMP_TZ,
        "timestamp_tz": dtypes.TIMESTAMP_TZ,
        "varbinary": dtypes.BINARY,
        "varchar": dtypes.VARCHAR,
        "variant": dtypes.JSON
    }

    MATERIALIZATION_MAPPINGS = {"BASE TABLE": mz.TABLE, "VIEW": mz.VIEW}

    @overrides
    def _get_all_databases(self) -> List[str]:
        """ Use the SHOW api to get all the available db structures."""
        logger.debug('Collecting databases from snowflake...')
        show_result = self._safe_query("SHOW TERSE DATABASES")['name'].tolist()
        databases = list(set(show_result))
        logger.debug(f'Done. Found {len(databases)} databases.')
        return databases

    @overrides
    def _get_all_schemas(
            self,
            database: str,
            exclude_defaults: Optional[bool] = False) -> List[str]:
        """Return the distinct schema names of ``database``.

        ``exclude_defaults`` is accepted for interface compatibility and is
        not used here.
        """
        logger.debug(f'Collecting schemas from {database} in snowflake...')
        show_result = self._safe_query(
            f'SHOW TERSE SCHEMAS IN DATABASE {database}')['name'].tolist()
        # de-duplicated, and a list as the annotation (and the sibling
        # _get_all_databases) promise
        schemas = list(set(show_result))
        logger.debug(
            f'Done. Found {len(schemas)} schemas in {database} database.')
        return schemas

    @staticmethod
    def population_count_statement(relation: Relation) -> str:
        """creates the count * statement for a relation

        Args:
            relation: the :class:`Relation <snowshu.core.models.relation.Relation>` to create the statement for.
        Returns:
            a query that results in a single row, single column, integer value of the unsampled relation population size
        """
        return f"SELECT COUNT(*) FROM {relation.quoted_dot_notation}"

    @staticmethod
    def view_creation_statement(relation: Relation) -> str:
        """Return a query yielding the view DDL text that follows its first ' AS '."""
        return f"""
SELECT
    SUBSTRING(GET_DDL('view','{relation.quoted_dot_notation}'),
    POSITION(' AS ' IN UPPER(GET_DDL('view','{relation.quoted_dot_notation}')))+3)
"""

    @staticmethod
    def unsampled_statement(relation: Relation) -> str:
        """Return a query selecting every row of the relation."""
        return f"""
SELECT
    *
FROM
    {relation.quoted_dot_notation}
"""

    def directionally_wrap_statement(
            self, sql: str, relation: Relation,
            sample_type: Optional['BaseSampleMethod']) -> str:
        """Wrap ``sql`` so that its result is itself sampled.

        Returns ``sql`` unchanged when ``sample_type`` is None.
        """
        if sample_type is None:
            return sql

        return f"""
WITH
    {relation.scoped_cte('SNOWSHU_FINAL_SAMPLE')} AS (
{sql}
)
,{relation.scoped_cte('SNOWSHU_DIRECTIONAL_SAMPLE')} AS (
SELECT
    *
FROM
    {relation.scoped_cte('SNOWSHU_FINAL_SAMPLE')}
{self._sample_type_to_query_sql(sample_type)}
)
SELECT
    *
FROM
    {relation.scoped_cte('SNOWSHU_DIRECTIONAL_SAMPLE')}
"""

    @staticmethod
    def analyze_wrap_statement(sql: str, relation: Relation) -> str:
        """Wrap ``sql`` into a query returning one row: sample_size, population_size."""
        return f"""
WITH
    {relation.scoped_cte('SNOWSHU_COUNT_POPULATION')} AS (
SELECT
    COUNT(*) AS population_size
FROM
    {relation.quoted_dot_notation}
)
,{relation.scoped_cte('SNOWSHU_CORE_SAMPLE')} AS (
{sql}
)
,{relation.scoped_cte('SNOWSHU_CORE_SAMPLE_COUNT')} AS (
SELECT
    COUNT(*) AS sample_size
FROM
    {relation.scoped_cte('SNOWSHU_CORE_SAMPLE')}
)
SELECT
    s.sample_size AS sample_size
    ,p.population_size AS population_size
FROM
    {relation.scoped_cte('SNOWSHU_CORE_SAMPLE_COUNT')} s
INNER JOIN
    {relation.scoped_cte('SNOWSHU_COUNT_POPULATION')} p
ON
    1=1
LIMIT 1
"""

    def sample_statement_from_relation(
            self, relation: Relation,
            sample_type: Union['BaseSampleMethod', None]) -> str:
        """builds the base sample statment for a given relation."""
        query = f"""
SELECT
    *
FROM
    {relation.quoted_dot_notation}
"""
        if sample_type is not None:
            query += f"{self._sample_type_to_query_sql(sample_type)}"
        return query

    @staticmethod
    def union_constraint_statement(subject: Relation, constraint: Relation,
                                   subject_key: str, constraint_key: str,
                                   max_number_of_outliers: int) -> str:
        """ Union statements to select outliers. This does not pull in NULL values. """
        return f"""
(SELECT
    *
FROM
    {subject.quoted_dot_notation}
WHERE
    {subject_key}
NOT IN
(SELECT
    {constraint_key}
FROM
    {constraint.quoted_dot_notation})
LIMIT {max_number_of_outliers})
"""

    @staticmethod
    def upstream_constraint_statement(relation: Relation, local_key: str,
                                      remote_key: str) -> str:
        """ builds upstream where constraints against downstream full population"""
        return f" {local_key} in (SELECT {remote_key} FROM {relation.quoted_dot_notation})"

    @staticmethod
    def predicate_constraint_statement(relation: Relation, analyze: bool,
                                       local_key: str,
                                       remote_key: str) -> str:
        """builds 'where' strings

        With ``analyze`` the predicate is a sub-select over the relation's
        core query; otherwise it lists the distinct values found in
        ``relation.data[remote_key]``.

        Raises:
            KeyError: ``remote_key`` is not a column of ``relation.data``.
        """
        constraint_sql = str()
        if analyze:
            constraint_sql = f" SELECT {remote_key} AS {local_key} FROM ({relation.core_query})"
        else:

            def quoted(val: Any) -> str:
                if relation.lookup_attribute(
                        remote_key).data_type.requires_quotes:
                    # a single quote inside a quoted literal is escaped by
                    # doubling it; left alone it would end the literal early
                    escaped = str(val).replace("'", "''")
                    return f"'{escaped}'"
                return str(val)

            try:
                constraint_set = [
                    quoted(val) for val in relation.data[remote_key].unique()
                ]
                constraint_sql = ','.join(constraint_set)
            except KeyError as err:
                logger.critical(
                    f'failed to build predicates for {relation.dot_notation}: '
                    f'remote key {remote_key} not in dataframe columns ({relation.data.columns})'
                )
                raise err

        return f"{local_key} IN ({constraint_sql}) "

    @staticmethod
    def polymorphic_constraint_statement(
            relation: Relation,  # noqa pylint: disable=too-many-arguments
            analyze: bool,
            local_key: str,
            remote_key: str,
            local_type: str,
            local_type_match_val: Optional[str] = None) -> str:
        """Build a key predicate combined with a case-insensitive type match.

        When ``local_type_match_val`` is not given, the type value is the
        relation name with one trailing 's' removed.
        """
        predicate = SnowflakeAdapter.predicate_constraint_statement(
            relation, analyze, local_key, remote_key)
        if local_type_match_val:
            type_match_val = local_type_match_val
        else:
            type_match_val = relation.name[:-1] if relation.name[-1].lower(
            ) == 's' else relation.name
        return f" ({predicate} AND LOWER({local_type}) = LOWER('{type_match_val}') ) "

    @staticmethod
    def _sample_type_to_query_sql(sample_type: 'BaseSampleMethod') -> str:
        """Translate a sample method into its SAMPLE clause.

        Raises:
            NotImplementedError: the method is neither BERNOULLI nor SYSTEM.
        """
        if sample_type.name == 'BERNOULLI':
            # a falsy probability means the sample is sized in rows
            qualifier = sample_type.probability if sample_type.probability\
                else str(sample_type.rows) + ' ROWS'
            return f"SAMPLE BERNOULLI ({qualifier})"
        if sample_type.name == 'SYSTEM':
            return f"SAMPLE SYSTEM ({sample_type.probability})"

        message = f"{sample_type.name} is not supported for SnowflakeAdapter"
        logger.error(message)
        raise NotImplementedError(message)

    # TODO: change arg name in parent to the fix issue here
    @overrides
    def _build_conn_string(self, overrides: Optional[dict] = None) -> str:  # noqa pylint: disable=redefined-outer-name
        """overrides the base conn string.

        Args:
            overrides: optional ``database`` and/or ``schema`` values taking
                precedence over the instance credentials.
        """
        overrides = overrides or {}
        # get_connection hands its database/schema overrides in here; they
        # must end up in the connection string to have any effect
        database = overrides.get('database') or self.credentials.database
        schema = overrides.get('schema') or self.credentials.schema

        conn_parts = [
            f"snowflake://{self.credentials.user}:{self.credentials.password}"
            f"@{self.credentials.account}/{database}/"
        ]
        conn_parts.append(schema if schema is not None else '')

        get_args = list()
        for arg in (
                'warehouse',
                'role',
        ):
            if self.credentials.__dict__[arg] is not None:
                get_args.append(f"{arg}={self.credentials.__dict__[arg]}")

        get_string = "?" + "&".join(get_args)
        return (''.join(conn_parts)) + get_string

    @overrides
    def _get_relations_from_database(
            self,
            schema_obj: BaseSourceAdapter._DatabaseObject) -> List[Relation]:
        """Build the relations (with attributes) of one schema from INFORMATION_SCHEMA."""
        quoted_database = schema_obj.full_relation.quoted(
            schema_obj.full_relation.database)  # quoted db name
        relation_database = schema_obj.full_relation.database  # case corrected db name
        case_sensitive_schema = schema_obj.case_sensitive_name  # case sensitive schema name
        relations_sql = f"""
SELECT
    m.table_schema AS schema,
    m.table_name AS relation,
    m.table_type AS materialization,
    c.column_name AS attribute,
    c.ordinal_position AS ordinal,
    c.data_type AS data_type
FROM
    {quoted_database}.INFORMATION_SCHEMA.TABLES m
INNER JOIN
    {quoted_database}.INFORMATION_SCHEMA.COLUMNS c
ON
    c.table_schema = m.table_schema
AND
    c.table_name = m.table_name
WHERE
    m.table_schema = '{case_sensitive_schema}'
    AND m.table_schema <> 'INFORMATION_SCHEMA'
"""

        logger.debug(
            f'Collecting detailed relations from database {quoted_database}...'
        )
        relations_frame = self._safe_query(relations_sql)
        unique_relations = (relations_frame['schema'] + '.' +
                            relations_frame['relation']).unique().tolist()
        logger.debug(
            f'Done collecting relations. Found a total of {len(unique_relations)} '
            f'unique relations in database {quoted_database}')
        relations = list()
        for relation in unique_relations:
            logger.debug(
                f'Building relation { quoted_database + "." + relation }...')
            attributes = list()

            for attribute in relations_frame.loc[(
                    relations_frame['schema'] + '.' +
                    relations_frame['relation']) == relation].itertuples():
                logger.debug(
                    f'adding attribute {attribute.attribute} to relation..')
                attributes.append(
                    Attribute(self._correct_case(attribute.attribute),
                              self._get_data_type(attribute.data_type)))

            # the last row of the group supplies the relation-level fields;
            # every name in unique_relations matches at least one row
            relation = Relation(
                relation_database,
                self._correct_case(attribute.schema),  # noqa pylint: disable=undefined-loop-variable
                self._correct_case(attribute.relation),  # noqa pylint: disable=undefined-loop-variable
                self.MATERIALIZATION_MAPPINGS[attribute.materialization],  # noqa pylint: disable=undefined-loop-variable
                attributes)
            logger.debug(f'Added relation {relation.dot_notation} to pool.')
            relations.append(relation)

        logger.debug(
            f'Acquired {len(relations)} total relations from database {quoted_database}.'
        )
        return relations

    @overrides
    def _count_query(self, query: str) -> int:
        """Return the number of rows ``query`` would produce."""
        count_sql = (
            f"WITH __SNOWSHU__COUNTABLE__QUERY as ({query}) "
            f"SELECT COUNT(*) AS count FROM __SNOWSHU__COUNTABLE__QUERY")
        count = int(self._safe_query(count_sql).iloc[0]['count'])
        return count

    @tenacity.retry(wait=wait_exponential(),
                    stop=stop_after_attempt(4),
                    before_sleep=Logger().log_retries,
                    reraise=True)
    @overrides
    def check_count_and_query(self, query: str, max_count: int,
                              unsampled: bool) -> pd.DataFrame:
        """checks the count, if count passes returns results as a dataframe.

        Args:
            query: the query to run.
            max_count: largest row count allowed for a sampled query.
            unsampled: when True an oversized result is only reported with a
                warning and still loaded.

        Raises:
            TooManyRecords: a sampled query would return more than
                ``max_count`` rows.
        """
        logger.debug('Checking count for query...')
        start_time = time.time()
        count = self._count_query(query)

        # explicit comparison rather than an assert: asserts are stripped
        # under -O, which would silently lift the limit
        if count > max_count:
            if not unsampled:
                message = (
                    f'failed to execute query, result would have returned {count} rows '
                    f'but the max allowed rows for this type of query is {max_count}.'
                )
                logger.error(message)
                logger.debug(f'failed sql: {query}')
                raise TooManyRecords(message)
            warn_msg = (
                f'Unsampled relation has {count} rows which is over '
                f'the max allowed rows for this type of query ({max_count}). '
                f'All records will be loaded into replica.')
            logger.warning(warn_msg)
        else:
            logger.debug(
                f'Query count safe at {count} rows in {time.time()-start_time} seconds.'
            )

        response = self._safe_query(query)
        return response

    @overrides
    def get_connection(
        self,
        database_override: Optional[str] = None,
        schema_override: Optional[str] = None
    ) -> sqlalchemy.engine.base.Engine:
        """Creates a connection engine without transactions.

        By default uses the instance credentials unless database or
        schema override are provided.

        Raises:
            KeyError: the credentials have not been set yet.
        """
        if not self._credentials:
            raise KeyError(
                'Adapter.get_connection called before setting Adapter.credentials'
            )

        logger.debug(f'Aquiring {self.CLASSNAME} connection...')
        candidates = dict(database=database_override, schema=schema_override)
        overrides = {  # noqa pylint: disable=redefined-outer-name
            k: v
            for (k, v) in candidates.items() if v is not None
        }

        engine = sqlalchemy.create_engine(self._build_conn_string(overrides),
                                          poolclass=NullPool)
        logger.debug(f'engine aquired. Conn string: {repr(engine.url)}')
        return engine
'retmax': 10, 'idtype': 'acc', 'usehistory': 'y' } r = requests.post(esearch, data=search_params) search_results = r.json() result_count = int(search_results['esearchresult']['count']) query_key = search_results['esearchresult']['querykey'] web_env = search_results['esearchresult']['webenv'] print('Search for {} had {} results.'.format(SEARCH_TERM, result_count)) @retry(stop=stop_after_attempt(7), wait=wait_fixed(2)) def parse_fasta_xml(query_key, web_env, retstart, batch_size): fetch_url = 'https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi' params = { 'db': 'nuccore', 'rettype': 'fasta', 'retmode': 'xml', 'query_key': query_key, 'WebEnv': web_env, 'retstart': retstart, 'retmax': batch_size } r = requests.post(fetch_url, data=params) try: results = [] xml_results = objectify.fromstring(r.content)
class PolicyFetcher:
    """
    fetches policy from backend
    """
    # up to 5 attempts with a random exponential wait capped at 10,
    # re-raising the last error
    DEFAULT_RETRY_CONFIG = {
        'wait': wait.wait_random_exponential(max=10),
        'stop': stop.stop_after_attempt(5),
        'reraise': True,
    }

    def __init__(self, backend_url=None, token=None, retry_config=None):
        """
        Args:
            backend_url (optional): base url of the backend.
                Defaults to opal_client_config.SERVER_URL.
            token (optional): token used to build the authorization header.
                Defaults to opal_client_config.CLIENT_TOKEN.
            retry_config (optional): keyword arguments for the retry
                decorator. Defaults to DEFAULT_RETRY_CONFIG.
        """
        self._backend_url = backend_url or opal_client_config.SERVER_URL
        self._token = token or opal_client_config.CLIENT_TOKEN
        self._auth_headers = tuple_to_dict(
            get_authorization_header(self._token))
        self._retry_config = retry_config if retry_config is not None else self.DEFAULT_RETRY_CONFIG

    async def fetch_policy_bundle(
            self,
            directories: Optional[List[str]] = None,
            base_hash: Optional[str] = None) -> Optional[PolicyBundle]:
        """
        Fetches the bundle, retrying according to the retry config.

        Args:
            directories (optional): policy paths to request. Defaults to ['.'].
            base_hash (optional): sent to the backend as `base_hash` when given.

        Returns:
            the bundle, or None when the paths were not found or when all
            attempts failed (the failure is logged, not raised).
        """
        # None instead of a list literal: a default list would be one object
        # shared by every call
        if directories is None:
            directories = ['.']
        attempter = retry(**self._retry_config)(self._fetch_policy_bundle)
        try:
            return await attempter(directories=directories,
                                   base_hash=base_hash)
        except Exception as err:
            logger.warning(
                "Failed all attempts to fetch bundle, got error: {err}",
                err=repr(err))
            return None

    async def _fetch_policy_bundle(
            self,
            directories: Optional[List[str]] = None,
            base_hash: Optional[str] = None) -> Optional[PolicyBundle]:
        """
        Fetches the bundle. May throw, in which case we retry again.

        Returns None (without raising) when the backend answers 404.
        """
        if directories is None:
            directories = ['.']
        params = {"path": directories}
        if base_hash is not None:
            params["base_hash"] = base_hash
        async with aiohttp.ClientSession() as session:
            try:
                async with session.get(f"{self._backend_url}/policy",
                                       headers={
                                           'content-type': 'text/plain',
                                           **self._auth_headers
                                       },
                                       params=params) as response:
                    if response.status == status.HTTP_404_NOT_FOUND:
                        logger.warning("requested paths not found: {paths}",
                                       paths=directories)
                        return None

                    # may throw ValueError
                    await throw_if_bad_status_code(
                        response, expected=[status.HTTP_200_OK])

                    # may throw Validation Error
                    bundle = await response.json()
                    return force_valid_bundle(bundle)
            except aiohttp.ClientError as e:
                logger.warning("server connection error: {err}", err=repr(e))
                raise
class GitLabProvider(Provider):
    """Provider that reads from and comments on a GitLab merge request."""

    def __init__(self, args, username, token, serverurl, project, mrnumber):
        """
        Connect to the server, load the merge request and register its diffs.

        Args:
            args: passed through to the ``Provider`` base class.
            username: passed through to the ``Provider`` base class.
            token: passed through to the ``Provider`` base class.
            serverurl: passed through to the ``Provider`` base class.
            project: numeric project id (converted with ``int``).
            mrnumber: merge request number (converted with ``int``).
        """
        super().__init__(args, username, token, serverurl)
        self.__project = int(project)
        self.__mrnum = int(mrnumber)
        self.__session = requests.Session()
        # NOTE(review): TLS certificate verification is disabled for every
        # request made through this session - confirm this is intended.
        self.__session.verify = False
        self.__mr = self.__get_mergerequest(self.__get_project())
        # get diffs
        self.__mr_changes = self.__mr.changes(all=True)
        for change in self.__mr_changes['changes']:
            self._changes.AddChange(change['new_path'], change['diff'],
                                    change['new_file'])
        # valid = neither merged nor closed, and not a draft unless drafts are allowed
        self.__isvalid = self.__mr.state not in ['merged', 'closed']
        self.__isvalid &= not self.__is_draft() or self.AllowDrafts

    @property
    def Valid(self):
        """Validity flag computed once in ``__init__``."""
        return self.__isvalid

    def __is_draft(self):
        """Return the merge request's ``work_in_progress`` attribute."""
        return self.__mr.work_in_progress

    @retry(wait=wait_exponential(multiplier=1, min=10, max=120),
           stop=stop_after_attempt(5))
    def __get_server_connection(self):
        """Create the GitLab client; any failure becomes SCABotServerCommError."""
        try:
            return gitlab.Gitlab(self.ServerURL,
                                 private_token=self.Token,
                                 session=self.__session)
        except Exception as e:
            logging.error('GitLab connections failed')
            raise SCABotServerCommError(e) from e

    @retry(wait=wait_exponential(multiplier=1, min=10, max=120),
           stop=stop_after_attempt(5))
    def __get_project(self):
        """Fetch the project; any failure becomes SCABotProjectNotFoundError."""
        try:
            return self.__get_server_connection().projects.get(self.__project)
        except Exception as e:
            logging.error('Project not found')
            raise SCABotProjectNotFoundError(self.__project) from e

    @retry(wait=wait_exponential(multiplier=1, min=10, max=120),
           stop=stop_after_attempt(5))
    def __get_mergerequest(self, proj):
        """Fetch the merge request from ``proj``; failures become SCABotRequestNotFoundError."""
        try:
            return proj.mergerequests.get(self.__mrnum)
        except Exception as e:
            logging.error('MR not found')
            raise SCABotRequestNotFoundError(self.__mrnum) from e

    @retry(wait=wait_exponential(multiplier=1, min=10, max=120),
           stop=stop_after_attempt(5))
    def SetNote(self, value: Note):
        """
        Create a discussion on the merge request for ``value``.

        Best effort: if creation fails, it is tried once more without the
        ``old_line`` position entry; if that fails too, only an error is
        logged and nothing is raised.
        """
        diff_refs = self.__mr_changes['diff_refs']
        _obj = {
            'body': value.body,
            'position': {
                'base_sha': diff_refs['base_sha'],
                'start_sha': diff_refs['start_sha'],
                'head_sha': diff_refs['head_sha'],
                'position_type': 'text',
                'old_line': value.lines[1] if len(value.lines) > 1 else None,
                'new_line': value.lines[0],
                'new_path': value.path,
            },
            'author': {
                'username': self.Username,
            },
        }
        try:
            self.__mr.discussions.create(_obj)
        except Exception:
            try:
                # second attempt without the old_line entry
                del _obj['position']['old_line']
                self.__mr.discussions.create(_obj)
            except Exception:
                logging.error('Set note {note} failed'.format(note=value))

    def GetNote(self, input):
        """Build a ``Note`` from a note dictionary, filling defaults for missing keys."""
        return Note(
            input.get('author', {}).get('username', 'Unkwown user'),
            input.get('body', ''),
            input.get('position', {}).get('new_path', ''),
            input.get('position', {}).get('new_line', -1),
            input.get('position', {}).get('old_line', None),
            input.get('resolved', False),
            input.get('resolvable', True),
            input,
        )

    @retry(wait=wait_exponential(multiplier=1, min=10, max=120),
           stop=stop_after_attempt(5))
    def GetNotes(self):
        """Return a ``Note`` for every note of every discussion on the merge request."""
        mr = self.__mr
        return [
            self.GetNote(note)
            for discussion in mr.discussions.list(all=True)
            for note in mr.discussions.get(discussion.id).attributes['notes']
        ]

    @retry(wait=wait_exponential(multiplier=1, min=10, max=120),
           stop=stop_after_attempt(5))
    def ResolveNote(self, value: Note):
        """Mark the discussion referenced by ``value`` as resolved."""
        # NOTE(review): value.reference is the note dict stored by GetNote, so
        # 'id' may be a note id rather than a discussion id - confirm.
        discussion = self.__mr.discussions.get(value.reference.get('id'))
        # python-gitlab updates a discussion through its `resolved` attribute;
        # the former `resolve` attribute is not a recognised update field.
        discussion.resolved = True
        discussion.save()

    @retry(wait=wait_exponential(multiplier=1, min=10, max=120),
           stop=stop_after_attempt(5))
    def GetCurrentStatus(self) -> dict:
        """Return the merge request object loaded in ``__init__``."""
        return self.__mr
msg= "More than one docker service is labeled as main service") node_details = await _get_node_details( app, client, list_running_services_with_uuid[0]) return node_details except aiodocker.exceptions.DockerError as err: log.exception("Error while accessing container with uuid: %s", node_uuid) raise exceptions.GenericDockerError( "Error while accessing container", err) from err @retry( wait=wait_fixed(2), stop=stop_after_attempt(3), reraise=True, retry=retry_if_exception_type(ClientConnectionError), ) async def _save_service_state(service_host_name: str, session: aiohttp.ClientSession): response: ClientResponse async with session.post( url=f"http://{service_host_name}/state", timeout=ServicesCommonSettings(). director_dynamic_service_save_timeout, ) as response: try: response.raise_for_status() except ClientResponseError as err:
APP_FRONTEND_CACHED_INDEXES_KEY = f"{__name__}.cached_indexes" APP_FRONTEND_CACHED_STATICS_JSON_KEY = f"{__name__}.cached_statics_json" # NOTE: saved as a separate item to config STATIC_WEBSERVER_SETTINGS_KEY = f"{__name__}.StaticWebserverModuleSettings" # # This retry policy aims to overcome the inconvenient fact that the swarm # orchestrator does not guaranteed the order in which services are started. # # Here the web-server needs to pull some files from the web-static service # which might still not be ready. # # RETRY_ON_STARTUP_POLICY = dict( stop=stop_after_attempt(5), wait=wait_fixed(1.5), before=before_log(log, logging.WARNING), retry=retry_if_exception_type(ClientConnectionError), reraise=True, ) def assemble_settings(app: web.Application) -> StaticWebserverModuleSettings: """creates stores and returns settings for this module""" settings = StaticWebserverModuleSettings() app[STATIC_WEBSERVER_SETTINGS_KEY] = settings return settings def get_settings(app: web.Application) -> StaticWebserverModuleSettings:
def _retry_blob_operation(func):
    """Wrap ``func`` in the retry policy shared by the blob operations below.

    Up to 10 attempts, a random 1-3 second wait between attempts, retrying on
    any exception, calling ``before_sleep`` before each wait and re-raising
    the last error once attempts are exhausted.
    """
    policy = retry(
        stop=stop_after_attempt(10),
        wait=wait_random(min=1, max=3),
        retry=retry_if_exception_type(),
        before_sleep=before_sleep,
        reraise=True,
    )
    return policy(func)


class ContainerWrapper:
    """Small convenience layer around a ``ContainerClient`` built from a container URL."""

    def __init__(self, container_url: str) -> None:
        self.client = ContainerClient.from_container_url(container_url)

    @_retry_blob_operation
    def upload_file(self, file_path: str, blob_name: str) -> None:
        """Upload the file at ``file_path`` as ``blob_name``, overwriting any existing blob."""
        with open(file_path, "rb") as stream:
            self.client.upload_blob(
                name=blob_name,
                data=stream,
                overwrite=True,
                max_concurrency=10,
            )

    def upload_file_data(self, data: str, blob_name: str) -> None:
        """Upload in-memory ``data`` as ``blob_name`` (not wrapped in the retry policy)."""
        self.client.upload_blob(
            name=blob_name,
            data=data,
            overwrite=True,
            max_concurrency=10,
        )

    def upload_dir(self, dir_path: str, recursive: bool = True) -> None:
        """Upload every regular file matched under ``dir_path``, named by its relative path."""
        pattern = os.path.join(dir_path, "**")
        for candidate in glob.glob(pattern, recursive=recursive):
            if not os.path.isfile(candidate):
                continue
            relative_name = os.path.relpath(candidate, start=dir_path)
            self.upload_file(candidate, relative_name)

    @_retry_blob_operation
    def delete_blob(self, blob_name: str) -> None:
        """Delete the blob called ``blob_name``."""
        self.client.delete_blob(blob_name)

    @_retry_blob_operation
    def download_blob(self, blob_name: str) -> bytes:
        """Download ``blob_name`` and return its content as bytes."""
        downloader = self.client.download_blob(blob_name)
        return cast(bytes, downloader.content_as_bytes())

    @_retry_blob_operation
    def list_blobs(self, *, name_starts_with: Optional[str] = None) -> List[str]:
        """Return the names of the blobs, optionally filtered by ``name_starts_with``."""
        names = []
        for blob in self.client.list_blobs(name_starts_with=name_starts_with):
            names.append(blob.name)
        return cast(List[str], names)
import asyncio
import time

import pytest
from servicelib.aiohttp import monitor_slow_callbacks
from servicelib.aiohttp.aiopg_utils import DatabaseError, retry
from tenacity.stop import stop_after_attempt
from tenacity.wait import wait_fixed


async def slow_task(delay):
    """Coroutine that calls the blocking ``time.sleep(delay)`` rather than awaiting."""
    # NOTE(review): the blocking sleep looks deliberate, so that the callback
    # runs longer than the slow-callback threshold - confirm.
    time.sleep(delay)


@retry(wait=wait_fixed(1), stop=stop_after_attempt(2))
async def fails_to_reach_pg_db():
    """Coroutine that always raises ``DatabaseError``; retried with a fixed 1s wait, 2 attempts."""
    raise DatabaseError


@pytest.fixture
def incidents_manager(event_loop):
    """Enable slow-callback monitoring and schedule slow and failing tasks on ``event_loop``."""
    # list handed to the monitor as its ``incidents`` argument
    incidents = []
    monitor_slow_callbacks.enable(slow_duration_secs=0.2, incidents=incidents)

    # three tasks whose delays (0.3s, 0.3s, 0.4s) exceed the 0.2s threshold
    f1a = asyncio.ensure_future(slow_task(0.3), loop=event_loop)
    f1b = asyncio.ensure_future(slow_task(0.3), loop=event_loop)
    f1c = asyncio.ensure_future(slow_task(0.4), loop=event_loop)

    incidents_pg = None  # aiopg_utils.monitor_pg_responsiveness.enable()
    # task that fails with DatabaseError on every attempt
    f2 = asyncio.ensure_future(fails_to_reach_pg_db(), loop=event_loop)
class BaseFetchProvider:
    """
    Base class for data fetching providers.
     - Override self._fetch_ to implement fetching
     - call self.fetch() to retrieve data (wrapped in retries and safe execution guards)
     - override __aenter__ and __aexit__ for async context
    """
    # tenacity keyword arguments used when no retry_config is given
    DEFAULT_RETRY_CONFIG = {
        'wait': wait.wait_random_exponential(),
        "stop": stop.stop_after_attempt(200),
        "reraise": True
    }

    def __init__(self, event: FetchEvent, retry_config=None) -> None:
        """
        Args:
            event (FetchEvent): the event describing what we should fetch
            retry_config (dict): Tenacity.retry config (@see https://tenacity.readthedocs.io/en/latest/api.html#retry-main-api) for retrying fetching
        """
        # convert the event as needed and save it
        self._event = self.parse_event(event)
        self._url = event.url
        self._retry_config = (
            retry_config if retry_config is not None
            else self.DEFAULT_RETRY_CONFIG
        )

    def parse_event(self, event: FetchEvent) -> FetchEvent:
        """
        Parse the event (and config within it) into the right object type.
        The base implementation returns the event unchanged.

        Args:
            event (FetchEvent): the event to be parsed

        Returns:
            FetchEvent: an event deriving from FetchEvent
        """
        return event

    async def fetch(self):
        """
        Fetch and return data.
        Calls self._fetch_ with a retry mechanism
        """
        attempter = retry(**self._retry_config)(self._fetch_)
        return await attempter()

    async def process(self, data):
        """
        Run self._process_ on the fetched data.
        Failures are logged with their traceback and then re-raised.
        """
        try:
            return await self._process_(data)
        except Exception:
            # Exception rather than a bare except: task cancellation and
            # interpreter-exit signals must not be reported as processing failures
            logger.exception("Failed to process fetched data")
            raise

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type=None, exc_val=None, tb=None):
        pass

    async def _fetch_(self):
        """
        Internal fetch operation called by self.fetch()
        Override this method to implement a new fetch provider
        """
        pass

    async def _process_(self, data):
        """Internal processing step called by self.process(); returns data unchanged by default."""
        return data

    def set_retry_config(self, retry_config: dict):
        """
        Set the configuration for retrying failed fetches
        @see self.DEFAULT_RETRY_CONFIG

        Args:
            retry_config (dict): Tenacity retry config
        """
        self._retry_config = retry_config
class ExperimentalSpotifyPodcastAPI:
    """Representation of the experimental Spotify podcast API."""

    def __init__(self):
        self._bearer: Optional[str] = None
        self._bearer_expires: Optional[dt.datetime] = None
        # Reentrant lock: _ensure_auth holds it while calling _authenticate,
        # which acquires it again.
        self._auth_lock = RLock()

    @retry(wait=wait_exponential(), stop=stop_after_attempt(7))
    def _authenticate(self):
        """
        Retrieves a Bearer token for the experimental Spotify API, valid 1 hour.

        Generally follows the steps outlined here:
        https://developer.spotify.com/documentation/general/guides/authorization/code-flow/
        (with a few exceptions)

        Any exception raised inside (HTTP errors, an unexpected authorization
        response) makes the retry decorator run the whole flow again.
        """
        with self._auth_lock:
            logger.info("Retrieving Bearer for experimental Spotify API...")

            logger.debug("Generating secrets")
            state = random_string(32)
            code_verifier = random_string(64)
            # URL-safe base64 of the SHA-256 digest, without '=' padding
            digest = hashlib.sha256(code_verifier.encode("utf-8")).digest()
            code_challenge = (
                base64.urlsafe_b64encode(digest).rstrip(b"=").decode("utf-8")
            )

            logger.trace("state = {}", state)
            logger.trace("code_verifier = {}", code_verifier)
            logger.trace("code_challenge = {}", code_challenge)

            logger.debug("Requesting User Authorization")
            response = requests.get(
                "https://accounts.spotify.com/oauth2/v2/auth",
                params={
                    "response_type": "code",
                    "client_id": CLIENT_ID,
                    "scope": "streaming ugc-image-upload user-read-email user-read-private",
                    "redirect_uri": "https://podcasters.spotify.com",
                    "code_challenge": code_challenge,
                    "code_challenge_method": "S256",
                    "state": state,
                    "response_mode": "web_message",  # TODO: Figure out if there is a way to get pure JSON
                    "prompt": "none",
                },
                cookies={
                    "sp_dc": SP_DC,
                    "sp_key": SP_KEY,
                },
            )
            response.raise_for_status()

            # We get some weird HTML here that contains some JS
            html = response.text
            match = re.search(
                r"const authorizationResponse = (.*?);", html, re.DOTALL)
            if match is None:
                raise ValueError(
                    "No authorization response found in Spotify auth page")
            json_str = match.group(1)
            # The extracted string isn't strictly valid JSON due to some missing quotes,
            # but PyYAML loads it fine
            auth_response = yaml.safe_load(json_str)

            # Confirm that auth was successful. Explicit raises instead of
            # assert, so the checks also run under `python -O`.
            if auth_response["type"] != "authorization_response":
                raise ValueError("Unexpected authorization response type")
            if auth_response["response"]["state"] != state:
                raise ValueError("Authorization response state mismatch")
            auth_code = auth_response["response"]["code"]

            logger.trace("auth_code = {}", auth_code)

            logger.debug("Requesting Bearer Token")
            response = requests.post(
                "https://accounts.spotify.com/api/token",
                data={
                    "grant_type": "authorization_code",
                    "client_id": CLIENT_ID,
                    "code": auth_code,
                    "redirect_uri": "https://podcasters.spotify.com",
                    "code_verifier": code_verifier,
                },
            )
            response.raise_for_status()
            response_json = response.json()

            self._bearer = response_json["access_token"]
            expires_in = response_json["expires_in"]
            self._bearer_expires = dt.datetime.now() + dt.timedelta(seconds=expires_in)

            logger.trace("bearer = {}", self._bearer)
            logger.success("Bearer token retrieved!")

    def _ensure_auth(self):
        """Checks if Bearer token expires soon. If so, requests a new one."""
        with self._auth_lock:
            # Refresh when there is no token yet, or when it expires within
            # the next five minutes (also covers already-expired tokens).
            refresh_deadline = dt.datetime.now() + dt.timedelta(minutes=5)
            if (
                self._bearer is None
                or self._bearer_expires is None
                or self._bearer_expires < refresh_deadline
            ):
                self._authenticate()

    @staticmethod
    def _build_url(*path: str) -> str:
        """Join the path segments with '/' and append them to BASE_URL."""
        return f"{BASE_URL}{'/'.join(path)}"

    @staticmethod
    def _date_params(start: dt.date, end: dt.date) -> Dict[str, str]:
        """Build the ``start``/``end`` query parameters as ISO formatted dates."""
        return {
            "start": start.isoformat(),
            "end": end.isoformat(),
        }

    def _request(self, url: str, *, params: Optional[Dict[str, str]] = None) -> dict:
        """
        Perform an authenticated GET request and return the decoded JSON.

        Makes up to 6 attempts and sleeps ``delay`` seconds before each one
        (starting at DELAY_BASE). The delay doubles after a 429/502/503/504
        response; a 401 response triggers re-authentication.

        Raises:
            requests.HTTPError: for any other non-OK response.
            Exception: when all attempts were used up.
        """
        delay = DELAY_BASE
        for attempt in range(6):
            sleep(delay)
            self._ensure_auth()
            response = requests.get(
                url,
                params=params,
                headers={"Authorization": f"Bearer {self._bearer}"},
            )

            if response.status_code in (429, 502, 503, 504):
                delay *= 2
                logger.log(
                    ("INFO" if attempt < 3 else "WARNING"),
                    'Got {} for URL "{}", next delay: {}s',
                    response.status_code,
                    url,
                    delay,
                )
                continue

            elif response.status_code == 401:
                self._authenticate()
                continue

            if not response.ok:
                logger.error("Error in experimental API:")
                logger.info(response.status_code)
                logger.info(response.headers)
                logger.info(response.text)
                response.raise_for_status()

            return response.json()

        raise Exception("All retries failed!")

    def podcast_followers(self, podcast_id: str, start: dt.date, end: dt.date) -> dict:
        """Loads historic follower data for podcast.

        Args:
            podcast_id (str): ID of the podcast to request data for.
            start (dt.date): Earliest date to request data for.
            end (dt.date): Most recent date to request data for.

        Returns:
            dict: Response data from API.
        """
        url = self._build_url(
            "shows",
            podcast_id,
            "followers",
        )
        return self._request(url, params=self._date_params(start, end))

    def podcast_aggregate(
        self,
        podcast_id: str,
        start: dt.date,
        end: Optional[dt.date] = None,
    ) -> dict:
        """Loads podcast demographics data.

        Args:
            podcast_id (str): ID of the podcast to request data for.
            start (dt.date): Earliest date to request data for.
            end (Optional[dt.date], optional): Most recent date to request data for.
              Defaults to None. Will be set to ``start`` if None.

        Returns:
            dict: Response data from API.
        """
        if end is None:
            end = start
        url = self._build_url(
            "shows",
            podcast_id,
            "aggregate",
        )
        return self._request(url, params=self._date_params(start, end))

    def episode_performance(self, episode_id: str) -> dict:
        """Loads episode performance data.

        Args:
            episode_id (str): ID of the episode to request data for.

        Returns:
            dict: Response data from API.
        """
        url = self._build_url("episodes", episode_id, "performance")
        return self._request(url)

    def episode_aggregate(
        self,
        episode_id: str,
        start: dt.date,
        end: Optional[dt.date] = None,
    ) -> dict:
        """Loads episode demographics data.

        Args:
            episode_id (str): ID of the episode to request data for.
            start (dt.date): Earliest date to request data for.
            end (Optional[dt.date], optional): Most recent date to request data for.
              Defaults to None. Will be set to ``start`` if None.

        Returns:
            dict: Response data from API.
        """
        if end is None:
            end = start
        url = self._build_url(
            "episodes",
            episode_id,
            "aggregate",
        )
        return self._request(url, params=self._date_params(start, end))
def get_localhost_ip(default="127.0.0.1") -> str: """Return the IP address for localhost""" local_ip = default s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) try: # doesn't even have to be reachable s.connect(("10.255.255.255", 1)) local_ip = s.getsockname()[0] finally: s.close() return local_ip @retry( wait=wait_fixed(2), stop=stop_after_attempt(10), after=after_log(log, logging.WARNING), ) def get_service_published_port( service_name: str, target_ports: Optional[Union[List[int], int]] = None) -> str: # WARNING: ENSURE that service name exposes a port in # Dockerfile file or docker-compose config file # NOTE: retries since services can take some time to start client = docker.from_env() services = [ s for s in client.services.list() if str(s.name).endswith(service_name) ] if not services:
user_reviews[html.unescape(review["author"])] = { "date": dateparser.parse(review["datePublished"], locales=["de", "en"]).date(), "title": html.unescape(review["name"]), "text": html.unescape(review.get("reviewBody", "")), "rating": review["reviewRating"]["ratingValue"], } return (user_rating, user_reviews) @retry(wait=wait_exponential(), stop=stop_after_attempt(4)) def _get_apple_meta_data(**params) -> types.JSON: """Retrieve data from iTunes Search API. Returns: types.JSON: Raw JSON representation of search result. """ result = requests.get(BASE_URL, params=params) result.raise_for_status() return result.json() def _get_metadata_url(podcast: Podcast, ) -> Optional[str]: """Retrieve meta data URL from API based on podcast.
from .director_v2_models import ClusterCreate, ClusterPatch, ClusterPing from .director_v2_settings import ( DirectorV2Settings, get_client_session, get_plugin_settings, ) log = logging.getLogger(__name__) _APP_DIRECTOR_V2_CLIENT_KEY = f"{__name__}.DirectorV2ApiClient" SERVICE_HEALTH_CHECK_TIMEOUT = ClientTimeout(total=2, connect=1) # type:ignore DEFAULT_RETRY_POLICY = dict( wait=wait_random(0, 1), stop=stop_after_attempt(2), reraise=True, before_sleep=before_sleep_log(log, logging.WARNING), ) DataType = Dict[str, Any] DataBody = Union[DataType, List[DataType], None] class DirectorV2ApiClient: def __init__(self, app: web.Application) -> None: self._app = app self._settings: DirectorV2Settings = get_plugin_settings(app) async def start(self, project_id: ProjectID, user_id: UserID, **options) -> str:
import tenacity from settings_library.rabbit import RabbitSettings from tenacity.before_sleep import before_sleep_log from tenacity.stop import stop_after_attempt from tenacity.wait import wait_fixed from .helpers.utils_docker import get_localhost_ip, get_service_published_port # HELPERS ------------------------------------------------------------------------------------ log = logging.getLogger(__name__) @tenacity.retry( wait=wait_fixed(5), stop=stop_after_attempt(60), before_sleep=before_sleep_log(log, logging.INFO), reraise=True, ) async def wait_till_rabbit_responsive(url: str) -> None: connection = await aio_pika.connect(url) await connection.close() # FIXTURES ------------------------------------------------------------------------------------ @pytest.fixture(scope="function") async def rabbit_settings( docker_stack: Dict, testing_environ_vars: Dict # stack is up
class GitHubProvider(Provider):
    """Provider that reads from and comments on a GitHub pull request."""

    def __init__(self, args, username, token, serverurl, project, mrnumber):
        """
        Log in, load the repository and pull request and register its diff.

        Args:
            args: passed through to the ``Provider`` base class.
            username: passed through to the ``Provider`` base class; also
                used as the repository owner.
            token: passed through to the ``Provider`` base class.
            serverurl: passed through to the ``Provider`` base class.
            project: repository name.
            mrnumber: pull request number.
        """
        super().__init__(args, username, token, serverurl)
        self.__project = project
        self.__mrnum = mrnumber
        self.__repo = self.__get_connection()
        self.__pr = self.__get_pr()
        self._changes.AddChangeFromCollection(self.__pr.diff())
        # valid = open, and not a draft unless drafts are allowed
        self.__isvalid = self.__pr.state == 'open'
        self.__isvalid &= not self.__is_draft() or self.AllowDrafts

    @property
    def Valid(self):
        """Validity flag computed once in ``__init__``."""
        return self.__isvalid

    def __is_draft(self):
        """Return the pull request's ``draft`` attribute."""
        return self.__pr.draft

    def __get_github_repo(self):
        """Return the repository name given at construction."""
        return self.__project

    @retry(wait=wait_exponential(multiplier=1, min=10, max=120),
           stop=stop_after_attempt(5))
    def __get_connection(self):
        """
        Log in and return the repository object.

        Raises:
            SCABotServerCommError: when the login yields nothing.
            SCABotProjectNotFoundError: when the repository lookup yields nothing.
        """
        login = github3.login(self.Username, self.Token)
        if not login:
            raise SCABotServerCommError('Login failed. Check your credentials')
        res = login.repository(self.Username, self.__get_github_repo())
        if not res:
            raise SCABotProjectNotFoundError(self.__get_github_repo())
        return res

    @retry(wait=wait_exponential(multiplier=1, min=10, max=120),
           stop=stop_after_attempt(5))
    def __get_pr(self):
        """Fetch the pull request; raise SCABotRequestNotFoundError when the lookup yields nothing."""
        res = self.__repo.pull_request(self.__mrnum)
        if not res:
            raise SCABotRequestNotFoundError(self.__mrnum)
        return res

    @retry(wait=wait_exponential(multiplier=1, min=10, max=120),
           stop=stop_after_attempt(5))
    def SetNote(self, value: Note):
        """
        Create a review comment for ``value`` on the pull request.

        When the note carries no reference, the sha of the freshly fetched
        pull request head is used instead.

        Raises:
            SCABotServerCommError: wrapping any failure.
        """
        try:
            _ref = value.reference
            if not _ref:
                _ref = self.__get_pr().head.sha
            self.__pr.create_review_comment(value.body, _ref, value.path,
                                            value.lines[0])
        except Exception as e:
            raise SCABotServerCommError(e) from e

    def GetNote(self, input):
        """Build a ``Note`` from a review comment object, keeping it as the reference."""
        return Note(
            input.user.login,
            input.body_text,
            input.path,
            input.position,
            input.original_position,
            reference=input,
        )

    @retry(wait=wait_exponential(multiplier=1, min=10, max=120),
           stop=stop_after_attempt(5))
    def GetNotes(self):
        """Return a ``Note`` for every review comment of the pull request."""
        return [self.GetNote(x) for x in self.__pr.review_comments()]

    @retry(wait=wait_exponential(multiplier=1, min=10, max=120),
           stop=stop_after_attempt(5))
    def ResolveNote(self, value: Note):
        """Delete the review comment whose id matches ``value.reference.id``."""
        for c in self.__pr.review_comments():
            if c.id == value.reference.id:
                # Unfortunately github/v3-API doesn't allow
                # programmatic resolving of comments
                # so we are going to delete them instead
                c.delete()
                break

    @retry(wait=wait_exponential(multiplier=1, min=10, max=120),
           stop=stop_after_attempt(5))
    def GetCurrentStatus(self) -> dict:
        """Return the attribute dictionary of the pull request object."""
        # `__dict__` ends in a double underscore, so it is not name-mangled
        # (the former `__dict__id` was, and never resolved to an attribute)
        return self.__pr.__dict__