def check_elastic_empty() -> bool:
    """
    Check the Elasticsearch connection.
    Use probing settings and not the default ES cluster ones so that
    connection rejects due to ES not being fully started yet are handled correctly
    :return: True if the events template is not found (the ES store is empty)
    """
    cluster_conf = es_factory.get_cluster_config("events")
    max_retries = config.get("apiserver.elastic.probing.max_retries", 4)
    timeout = config.get("apiserver.elastic.probing.timeout", 30)
    es_logger = logging.getLogger("elasticsearch")
    log_filter = ConnectionErrorFilter(
        err_type=urllib3.exceptions.NewConnectionError, args_prefix=("GET",)
    )
    try:
        es_logger.addFilter(log_filter)
        for retry in range(max_retries):
            try:
                es = Elasticsearch(hosts=cluster_conf.get("hosts"))
                return not es.indices.get_template(name="events*")
            except exceptions.NotFoundError as ex:
                log.error(ex)
                return True
            except exceptions.ConnectionError as ex:
                if retry >= max_retries - 1:
                    raise ElasticConnectionError(
                        f"Error connecting to Elasticsearch: {str(ex)}"
                    )
                log.warning(
                    f"Could not connect to ElasticSearch Service. "
                    f"Retry {retry + 1} of {max_retries}. Waiting for {timeout}sec"
                )
                sleep(timeout)
    finally:
        es_logger.removeFilter(log_filter)
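# ConnectionErrorFilter is used above but not defined in this section. Below is
# a minimal sketch of a logging.Filter consistent with that usage; the matching
# logic and attribute names are assumptions for illustration, not the actual
# implementation.
import logging


class ConnectionErrorFilter(logging.Filter):
    """Drop log records for an expected connection-error type during probing."""

    def __init__(self, err_type, args_prefix=()):
        super().__init__()
        self.err_type = err_type
        self.args_prefix = tuple(args_prefix)

    def filter(self, record: logging.LogRecord) -> bool:
        # Suppress records whose positional args start with the given prefix
        # and whose attached exception (if any) is of the filtered type
        try:
            args = tuple(record.args or ())
            if args[: len(self.args_prefix)] != self.args_prefix:
                return True
            exc_type = record.exc_info[0] if record.exc_info else None
            return bool(exc_type) and not issubclass(exc_type, self.err_type)
        except Exception:
            return True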
def start_sender(cls):
    if not cls.supported:
        return
    url = config.get("apiserver.statistics.url")
    retries = config.get("apiserver.statistics.max_retries", 5)
    max_backoff = config.get("apiserver.statistics.max_backoff_sec", 5)
    session = requests.Session()
    adapter = HTTPAdapter(max_retries=Retry(retries))
    session.mount("http://", adapter)
    session.mount("https://", adapter)
    session.headers["Content-type"] = "application/json"
    WarningFilter.attach()
    while not ThreadsManager.terminating:
        try:
            report = cls.send_queue.get()
            # Set a random backoff factor each time we send a report
            adapter.max_retries.backoff_factor = random.random() * max_backoff
            session.post(url, data=dumps(report))
        except Exception:
            # Never let a failed report kill the sender loop
            pass
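# Why mutate adapter.max_retries.backoff_factor per report: urllib3's Retry
# sleeps roughly backoff_factor * (2 ** (n - 1)) seconds before the n-th retry
# (the very first retry is typically immediate), so drawing a random factor in
# [0, max_backoff) spreads retry bursts from many servers over time. A small
# self-contained illustration (values are examples only):
def _example_backoff(factor: float = 0.4, attempts: int = 3) -> list:
    return [factor * (2 ** (n - 1)) for n in range(1, attempts + 1)]


# _example_backoff(0.4) -> [0.4, 0.8, 1.6] seconds between successive retries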
def _configure(self):
    CORS(self.app, **config.get("apiserver.cors"))
    if get_bool("CLEARML_COMPRESS_RESP", default=True):
        Compress(self.app)
    self.app.config["SECRET_KEY"] = config.get("secure.http.session_secret.apiserver")
    self.app.config["JSONIFY_PRETTYPRINT_REGULAR"] = config.get(
        "apiserver.pretty_json"
    )
def get_guest_user(cls) -> Optional["FixedUser"]:
    if cls.guest_enabled():
        return cls(
            is_guest=True,
            username=config.get("services.auth.fixed_users.guest.username"),
            password=config.get("services.auth.fixed_users.guest.password"),
            name=config.get("services.auth.fixed_users.guest.name"),
            company=config.get("services.auth.fixed_users.guest.default_company"),
        )
def is_run_by_worker(t: Task) -> bool:
    """Checks if there is an active worker running the task"""
    update_timeout = config.get("apiserver.workers.task_update_timeout", 600)
    return (
        t.last_worker
        and t.last_update
        and (datetime.utcnow() - t.last_update).total_seconds() < update_timeout
    )
def __init__(self, redis_config_dict):
    self.aliases = {}
    for alias, alias_config in redis_config_dict.items():
        alias_config = alias_config.as_plain_ordered_dict()
        alias_config["password"] = config.get(f"secure.redis.{alias}.password", None)
        is_cluster = alias_config.get("cluster", False)

        host = OVERRIDE_HOST or alias_config.get("host", None)
        if host:
            alias_config["host"] = host

        port = OVERRIDE_PORT or alias_config.get("port", None)
        if port:
            alias_config["port"] = port

        password = OVERRIDE_PASSWORD or alias_config.get("password", None)
        if password:
            alias_config["password"] = password

        if not port or not host:
            raise ConfigError(
                "Redis configuration is invalid. missing port or host", alias=alias
            )

        if is_cluster:
            # RedisCluster does not accept the "cluster" and "db" parameters
            del alias_config["cluster"]
            alias_config.pop("db", None)
            self.aliases[alias] = RedisCluster(**alias_config)
        else:
            self.aliases[alias] = StrictRedis(**alias_config)
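# A hedged example of the shape of redis_config_dict consumed above. In the
# real code each value is a config section exposing as_plain_ordered_dict();
# plain dicts are shown here only to illustrate the expected keys, and all
# values are illustrative assumptions.
example_redis_config = {
    "apiserver": {
        "host": "127.0.0.1",
        "port": 6379,
        "db": 0,
    },
    "clustered_alias": {
        "cluster": True,  # routes this alias to RedisCluster instead of StrictRedis
        "host": "redis-cluster",
        "port": 7000,
        "db": 0,  # stripped before RedisCluster(**kwargs), which does not accept it
    },
}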
def get_cluster_config(cls, cluster_name):
    """
    Returns cluster config for the specified cluster path
    :param cluster_name: Dot separated cluster path in the configuration file
    :return: config section for the cluster
    :raises MissingClusterConfiguration: in case no config section is found for the cluster
    """
    cluster_key = ".".join(("hosts.elastic", cluster_name))
    cluster_config = config.get(cluster_key, None)
    if not cluster_config:
        raise MissingClusterConfiguration(cluster_name)

    def set_host_prop(key, value):
        for entry in cluster_config.get("hosts", []):
            entry[key] = value

    host, port = cls.get_override(cluster_name)
    if host:
        set_host_prop("host", host)
    if port:
        set_host_prop("port", port)

    return cluster_config
def _check_updates(self):
    update_interval_sec = max(
        float(
            config.get(
                "apiserver.check_for_updates.check_interval_sec",
                60 * 60 * 24,
            )
        ),
        60 * 5,
    )
    while not ThreadsManager.terminating:
        # noinspection PyBroadException
        try:
            response = self._check_new_version_available()
            if response:
                if response.patch_upgrade:
                    log.info(
                        f"{self.component_name.upper()} new package available: upgrade to v{response.version} "
                        f"is recommended!\nRelease Notes:\n{os.linesep.join(response.description)}"
                    )
                else:
                    log.info(
                        f"{self.component_name.upper()} new version available: upgrade to v{response.version}"
                        f" is recommended!"
                    )
        except Exception:
            log.exception("Failed obtaining updates")
        sleep(update_interval_sec)
def _pre_populate(company_id: str, zip_file: str):
    if not zip_file or not Path(zip_file).is_file():
        msg = f"Invalid pre-populate zip file: {zip_file}"
        if config.get("apiserver.pre_populate.fail_on_error", False):
            log.error(msg)
            raise ValueError(msg)
        else:
            log.warning(msg)
    else:
        log.info(f"Pre-populating using {zip_file}")
        PrePopulate.import_from_zip(
            zip_file,
            artifacts_path=config.get("apiserver.pre_populate.artifacts_path", None),
        )
class GetTokenRequest(Base):
    """ User requests a token """

    expiration_sec = IntField(
        validators=Max(config.get("apiserver.auth.max_expiration_sec")),
        nullable=True,
    )
    """ Expiration time for token in seconds. """
def get_queue_metrics(
    self,
    company_id: str,
    from_date: float,
    to_date: float,
    interval: int,
    queue_ids: Sequence[str],
) -> dict:
    """
    Get the company queue metrics in the specified time range.
    Returned as date histograms of average values per queue and metric type.
    The from_date is extended by 'metrics_before_from_date' seconds from
    queues.conf to compensate for a possibly small number of points. The
    default extension is 3600s.
    In case no queue ids are specified the avg across all the company queues
    is calculated for each metric
    """
    # self._log_current_metrics(company, queue_ids=queue_ids)
    if from_date >= to_date:
        raise bad_request.FieldsValueError("from_date must be less than to_date")
    seconds_before = config.get("services.queues.metrics_before_from_date", 3600)
    must_terms = [QueryBuilder.dates_range(from_date - seconds_before, to_date)]
    if queue_ids:
        must_terms.append(QueryBuilder.terms("queue", queue_ids))
    es_req = {
        "size": 0,
        "query": {"bool": {"must": must_terms}},
        "aggs": self._get_dates_agg(interval),
    }
    with translate_errors_context(), TimingContext("es", "get_queue_metrics"):
        res = self._search_company_metrics(company_id, es_req)
    if "aggregations" not in res:
        return {}
    date_metrics = [
        dict(
            timestamp=d["key"],
            queue_metrics=self._extract_queue_metrics(d["queues"]["buckets"]),
        )
        for d in res["aggregations"]["dates"]["buckets"]
        if d["doc_count"] > 0
    ]
    if queue_ids:
        return self._datetime_histogram_per_queue(date_metrics)
    return self._average_datetime_histogram(date_metrics)
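# _get_dates_agg is not shown in this section. A minimal sketch of the kind of
# aggregation body it plausibly returns, consistent with the parsing above
# (d["key"], d["doc_count"], d["queues"]["buckets"]); the field names
# "timestamp", "queue" and the inner average are assumptions for illustration:
def _example_dates_agg(interval: int) -> dict:
    return {
        "dates": {
            "date_histogram": {"field": "timestamp", "interval": f"{interval}s"},
            "aggs": {
                "queues": {
                    "terms": {"field": "queue"},
                    "aggs": {"average": {"avg": {"field": "value"}}},
                }
            },
        }
    }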
def is_guest_endpoint(cls, service, action):
    """
    Validate a potential guest user. This method verifies that the user is
    indeed the guest user and that the guest user may access the
    service/action using its username/password
    """
    return any(
        ep == ".".join((service, action))
        for ep in config.get("services.auth.fixed_users.guest.allow_endpoints", [])
    )
def update_interval(self):
    return timedelta(
        seconds=max(
            float(
                config.get(
                    "apiserver.check_for_updates.check_interval_sec",
                    60 * 60 * 24,
                )
            ),
            60 * 5,
        )
    )
def test_project_aggregations(self):
    """This test requires user with user_auth_only... credentials in db"""
    user2_client = APIClient(
        api_key=config.get("apiclient.user_auth_only"),
        secret_key=config.get("apiclient.user_auth_only_secret"),
        base_url="http://localhost:8008/v2.13",
    )
    child = self._temp_project(name="Aggregation/Pr1", client=user2_client)
    project = self.api.projects.get_all_ex(name="^Aggregation$").projects[0].id
    child_project = self.api.projects.get_all_ex(id=[child]).projects[0]
    self.assertEqual(child_project.parent.id, project)
    user = self.api.users.get_current_user().user.id

    # test aggregations on project with empty subprojects
    res = self.api.users.get_all_ex(active_in_projects=[project])
    self.assertEqual(res.users, [])
    res = self.api.projects.get_all_ex(id=[project], active_users=[user])
    self.assertEqual(res.projects, [])
    res = self.api.models.get_frameworks(projects=[project])
    self.assertEqual(res.frameworks, [])
    res = self.api.tasks.get_types(projects=[project])
    self.assertEqual(res.types, [])
    res = self.api.projects.get_task_parents(projects=[project])
    self.assertEqual(res.parents, [])

    # test aggregations with non-empty subprojects
    task1 = self._temp_task(project=child)
    self._temp_task(project=child, parent=task1)
    framework = "Test framework"
    self._temp_model(project=child, framework=framework)
    res = self.api.users.get_all_ex(active_in_projects=[project])
    self._assert_ids(res.users, [user])
    res = self.api.projects.get_all_ex(id=[project], active_users=[user])
    self._assert_ids(res.projects, [project])
    res = self.api.projects.get_task_parents(projects=[project])
    self._assert_ids(res.parents, [task1])
    res = self.api.models.get_frameworks(projects=[project])
    self.assertEqual(res.frameworks, [framework])
    res = self.api.tasks.get_types(projects=[project])
    self.assertEqual(res.types, ["testing"])
def _check_new_version_available(self) -> Optional[_VersionResponse]:
    url = config.get(
        "apiserver.check_for_updates.url",
        "https://updates.trains.allegro.ai/updates",
    )
    uid = Settings.get_by_key("server.uuid")
    response = requests.get(
        url,
        json={"versions": {self.component_name: str(get_version())}, "uid": uid},
        timeout=float(
            config.get("apiserver.check_for_updates.request_timeout_sec", 3.0)
        ),
    )
    if not response.ok:
        return
    response = response.json().get(self.component_name)
    if not response:
        return
    latest_version = response.get("version")
    if not latest_version:
        return
    cur_version = Version(get_version())
    latest_version = Version(latest_version)
    if cur_version >= latest_version:
        return
    return self._VersionResponse(
        version=str(latest_version),
        patch_upgrade=(
            latest_version.major == cur_version.major
            and latest_version.minor == cur_version.minor
        ),
        # guard against a missing description so split() cannot fail on None
        description=(response.get("description") or "").split("\r\n"),
    )
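# Illustration of the patch_upgrade decision above, assuming Version comes
# from packaging.version (which exposes .major and .minor components):
from packaging.version import Version

cur = Version("1.2.3")
assert Version("1.2.7").minor == cur.minor  # same major.minor -> patch upgrade
assert Version("1.3.0").minor != cur.minor  # minor bump -> regular upgrade notice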
def from_scroll_id(cls, scroll_id: str):
    try:
        return cls(
            **jwt.decode(
                scroll_id,
                key=config.get(
                    "services.events.events_retrieval.scroll_id_key", "1234567890"
                ),
            )
        )
    except jwt.PyJWTError:
        raise ValueError("Invalid Scroll ID")
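# Hedged sketch of the encoding side implied by from_scroll_id: the scroll id
# is a JWT signed with the same configurable key. The helper name and payload
# are illustrative assumptions; note that PyJWT 2.x additionally requires
# algorithms=[...] to be passed to jwt.decode.
import jwt

SCROLL_ID_KEY = "1234567890"  # default of the config lookup above


def to_scroll_id(state: dict) -> str:
    return jwt.encode(state, key=SCROLL_ID_KEY)


# round trip: jwt.decode(to_scroll_id({"id": "abc"}), key=SCROLL_ID_KEY)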
def __init__(self, events_es=None, redis=None):
    self.es = events_es or es_factory.connect("events")
    self._metrics = EventMetrics(self.es)
    self._skip_iteration_for_metric = set(
        config.get("services.events.ignore_iteration.metrics", [])
    )
    self.redis = redis or redman.connection("apiserver")
    self.debug_images_iterator = DebugImagesIterator(es=self.es, redis=self.redis)
    self.debug_sample_history = DebugSampleHistory(es=self.es, redis=self.redis)
    self.events_iterator = EventsIterator(es=self.es)
def from_config(cls) -> Sequence["FixedUser"]:
    users = [
        cls(**user) for user in config.get("apiserver.auth.fixed_users.users", [])
    ]
    if cls.guest_enabled():
        users.insert(0, cls.get_guest_user())
    return users
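# A hedged example of the "apiserver.auth.fixed_users.users" entries parsed
# above; each mapping is passed as keyword arguments to the FixedUser
# constructor (field names inferred from the guest-user handling earlier,
# values are illustrative):
example_fixed_users = [
    {"username": "jane", "password": "secret", "name": "Jane Doe"},
    {"username": "john", "password": "secret2", "name": "John Doe"},
]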
def get_cache_manager(cls):
    if not cls._cache_manager:
        cls._cache_manager = RedisCacheManager(
            state_class=cls.GetManyScrollState,
            redis=redman.connection("apiserver"),
            expiration_interval=config.get(
                "services._mongo.scroll_state_expiration_seconds", 600
            ),
        )
    return cls._cache_manager
def get_credentials(
    cls, cluster_name: str, cluster_config: dict = None
) -> Optional[Tuple[str, str]]:
    cluster_config = cluster_config or cls.get_cluster_config(cluster_name)
    if not cluster_config.get("secure", True):
        return None

    elastic_user = OVERRIDE_USERNAME or config.get("secure.elastic.user", None)
    if not elastic_user:
        return None

    elastic_password = OVERRIDE_PASSWORD or config.get(
        "secure.elastic.password", None
    )
    if not elastic_password:
        raise MissingPasswordForElasticUser(
            f"cluster={cluster_name}, username={elastic_user}"
        )

    return elastic_user, elastic_password
def login(call: APICall, *_, **__):
    """ Generates a token based on the authenticated user (intended for use with credentials) """
    call.result.data_model = AuthBLL.get_token_for_user(
        user_id=call.identity.user,
        company_id=call.identity.company,
        expiration_sec=call.data_model.expiration_sec,
    )

    # Add authorization cookie
    call.result.cookies[
        config.get("apiserver.auth.session_auth_cookie_name")
    ] = call.result.data_model.token
def update_featured_projects_order(cls):
    order = config.get("services.projects.featured.order", [])
    if not order:
        return
    public_default = config.get("services.projects.featured.public_default", 9999)

    def get_index(p: Project):
        for index, entry in enumerate(order):
            if (
                entry.get("id", None) == p.id
                or entry.get("name", None) == p.name
                or ("name_regex" in entry and re.match(entry["name_regex"], p.name))
            ):
                return index
        return public_default

    for project in cls.project_cls.get_many_public(projection=["id", "name"]):
        featured_index = get_index(project)
        cls.project_cls.objects(id=project.id).update(featured=featured_index)
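# A hedged example of the "services.projects.featured.order" entries matched
# by get_index above; an entry may select a project by exact id, exact name,
# or a name regex, and earlier entries win (all values are illustrative):
example_featured_order = [
    {"id": "aa11bb22cc33dd44ee55ff6677889900"},
    {"name": "Demo Project"},
    {"name_regex": r"^Examples/.*"},
]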
class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase):
    tasks: Sequence[str] = ListField(
        items_types=str,
        validators=[
            Length(
                minimum_value=1,
                maximum_value=config.get(
                    "services.tasks.multi_task_histogram_limit", 10
                ),
            )
        ],
    )
def _create_api_call(self, req):
    call = None
    try:
        # Parse the request path
        path = req.path
        if self._request_strip_prefix and path.startswith(self._request_strip_prefix):
            path = path[len(self._request_strip_prefix) :]
        endpoint_version, endpoint_name = ServiceRepo.parse_endpoint_path(path)

        # Resolve authorization: if cookies contain an authorization token,
        # use it as a starting point. In any case, request headers always
        # take precedence.
        auth_cookie = req.cookies.get(
            config.get("apiserver.auth.session_auth_cookie_name")
        )
        headers = (
            {}
            if not auth_cookie
            else {"Authorization": f"{AuthType.bearer_token} {auth_cookie}"}
        )
        headers.update(
            list(req.headers.items())
        )  # add (possibly override with) the headers

        # Construct call instance
        call = APICall(
            endpoint_name=endpoint_name,
            remote_addr=req.remote_addr,
            endpoint_version=endpoint_version,
            headers=headers,
            files=req.files,
            host=req.host,
            auth_cookie=auth_cookie,
        )
    except PathParsingError as ex:
        call = self._call_or_empty_with_error(call, req, ex.args[0], 400)
        call.log_api = False
    except BadRequest as ex:
        call = self._call_or_empty_with_error(call, req, ex.description, 400)
    except BaseError as ex:
        call = self._call_or_empty_with_error(call, req, ex.msg, ex.code, ex.subcode)
    except Exception as ex:
        log.exception("Error creating call")
        call = self._call_or_empty_with_error(
            call, req, ex.args[0] if ex.args else type(ex).__name__, 500
        )
    return call
def _init_dbs(self):
    db.initialize()

    with distributed_lock(
        name=self._get_db_instance_key(),
        timeout=config.get("apiserver.db_init_timout", 120),
    ):
        upgrade_monitoring = config.get(
            "apiserver.elastic.upgrade_monitoring.v16_migration_verification", True
        )
        try:
            empty_es = check_elastic_empty()
        except ElasticConnectionError as err:
            if not upgrade_monitoring:
                raise
            log.error(err)
            info.es_connection_error = True

        empty_db = check_mongo_empty()
        if (
            upgrade_monitoring
            and not empty_db
            and (info.es_connection_error or empty_es)
            and get_last_server_version() < Version("0.16.0")
        ):
            log.info("ES database does not seem to be migrated")
            info.missed_es_upgrade = True

        if info.es_connection_error and not info.missed_es_upgrade:
            raise Exception(
                "Error starting server: failed connecting to ElasticSearch service"
            )

        if not info.missed_es_upgrade:
            init_es_data()
            init_mongo_data()

        if (
            not info.missed_es_upgrade
            and empty_db
            and config.get("apiserver.pre_populate.enabled", False)
        ):
            pre_populate_data()
def initialize(cls):
    db_entries = config.get("hosts.mongo", {})
    missing = []
    log.info("Initializing database connections")

    override_connection_string = getenv(OVERRIDE_CONNECTION_STRING_ENV_KEY)
    override_hostname = first(map(getenv, OVERRIDE_HOST_ENV_KEY), None)
    override_port = first(map(getenv, OVERRIDE_PORT_ENV_KEY), None)

    if override_connection_string:
        log.info(
            f"Using override mongodb connection string {override_connection_string}"
        )
    else:
        if override_hostname:
            log.info(f"Using override mongodb host {override_hostname}")
        if override_port:
            log.info(f"Using override mongodb port {override_port}")

    for key, alias in get_items(Database).items():
        if key not in db_entries:
            missing.append(key)
            continue

        entry = cls._create_db_entry(alias=alias, settings=db_entries.get(key))
        if override_connection_string:
            entry.host = override_connection_string
        else:
            if override_hostname:
                entry.host = furl(entry.host).set(host=override_hostname).url
            if override_port:
                entry.host = furl(entry.host).set(port=override_port).url
        try:
            entry.validate()
            log.info(
                "Registering connection to %(alias)s (%(host)s)" % entry.to_struct()
            )
            register_connection(**entry.to_struct())
            cls._entries.append(entry)
        except ValidationError as ex:
            raise Exception("Invalid database entry `%s`: %s" % (key, ex.args[0]))

    if missing:
        raise ValueError("Missing database configuration for %s" % ", ".join(missing))
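# Illustration of the override logic above: furl rewrites a single component
# of an existing mongodb connection string while keeping the rest intact.
from furl import furl

entry_host = "mongodb://mongo:27017/backend"
assert furl(entry_host).set(host="10.0.0.5").url == "mongodb://10.0.0.5:27017/backend"
assert furl(entry_host).set(port=27018).url == "mongodb://mongo:27018/backend"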
def unregister_worker(self, company_id: str, user_id: str, worker: str) -> None:
    """
    Unregister a worker
    :param company_id: worker's company ID
    :param user_id: user ID under which this worker is running
    :param worker: worker ID
    :raise bad_request.WorkerNotRegistered: the worker was not previously registered
    """
    with TimingContext("redis", "workers_unregister"):
        res = self.redis.delete(
            company_id, self._get_worker_key(company_id, user_id, worker)
        )
    if not res and not config.get("apiserver.workers.auto_unregister", False):
        raise bad_request.WorkerNotRegistered(worker=worker)
def __init__(
    self, exp=None, iat=None, nbf=None, env=None, identity=None, entities=None, **_
):
    super(Token, self).__init__(
        AuthType.bearer_token, identity=identity, entities=entities
    )
    self.exp = exp
    self.iat = iat
    self.nbf = nbf
    self._env = env or config.get("env", "<unknown>")
def __init__( self, api_key=None, secret_key=None, base_url=None, impersonated_user_id=None, session_token=None, ): if not session_token: self.api_key = (api_key or os.environ.get("SM_API_KEY") or config.get("apiclient.api_key")) if not self.api_key: raise ValueError( "APIClient requires api_key in constructor or config") self.secret_key = (secret_key or os.environ.get("SM_API_SECRET") or config.get("apiclient.secret_key")) if not self.secret_key: raise ValueError( "APIClient requires secret_key in constructor or config") self.base_url = (base_url or os.environ.get("SM_API_URL") or config.get("apiclient.base_url")) if not self.base_url: raise ValueError( "APIClient requires base_url in constructor or config") if self.base_url.endswith("/"): self.base_url = self.base_url[:-1] self.session_token = session_token # create http session self.http_session = requests.session() retries = config.get("apiclient.retries", 7) backoff_factor = config.get("apiclient.backoff_factor", 0.3) status_forcelist = config.get("apiclient.status_forcelist", (500, 502, 504)) retry = Retry( total=retries, read=retries, connect=retries, backoff_factor=backoff_factor, status_forcelist=status_forcelist, ) adapter = HTTPAdapter(max_retries=retry) self.http_session.mount("http://", adapter) self.http_session.mount("https://", adapter) if impersonated_user_id: self.http_session.headers[ "X-ClearML-Impersonate-As"] = impersonated_user_id if not self.session_token: self.login()