def _to_sqlalchemy_filtering_statement(sql_statement, session):
    """Build a filtering subquery for one parsed search clause.

    Attribute clauses are handled directly on ``SqlRun`` elsewhere, so they
    map to ``None``; comparators with no registered filter operation also
    yield ``None``.
    """
    stmt_type = sql_statement.get('type')
    stmt_key = sql_statement.get('key')
    stmt_value = sql_statement.get('value')
    stmt_comparator = sql_statement.get('comparator')

    if SearchUtils.is_metric(stmt_type, stmt_comparator):
        entity = SqlLatestMetric
        stmt_value = float(stmt_value)
    elif SearchUtils.is_param(stmt_type, stmt_comparator):
        entity = SqlParam
    elif SearchUtils.is_tag(stmt_type, stmt_comparator):
        entity = SqlTag
    elif SearchUtils.is_attribute(stmt_type, stmt_comparator):
        return None
    else:
        raise MlflowException("Invalid search expression type '%s'" % stmt_type,
                              error_code=INVALID_PARAMETER_VALUE)

    # validity of the comparator is checked in SearchUtils.parse_search_filter()
    op = SearchUtils.filter_ops.get(stmt_comparator)
    if op is None:
        return None
    return session.query(entity).filter(
        entity.key == stmt_key, op(entity.value, stmt_value)).subquery()
def test_order_by_metric_with_nans_infs_nones():
    """Sorting on a metric places NaN just before missing values, at both ends."""
    raw_values = ["nan", "inf", "-inf", "-1000", "0", "1000", "None"]

    def make_run(val):
        metric_value = None if val == "None" else float(val)
        return Run(
            run_info=RunInfo(
                run_id=val,
                run_uuid=val,
                experiment_id=0,
                user_id="user",
                status=RunStatus.to_string(RunStatus.FINISHED),
                start_time=0,
                end_time=1,
                lifecycle_stage=LifecycleStage.ACTIVE,
            ),
            run_data=RunData(metrics=[Metric("x", metric_value, 1, 0)]),
        )

    runs = [make_run(v) for v in raw_values]
    sorted_runs_asc = [r.info.run_id for r in SearchUtils.sort(runs, ["metrics.x asc"])]
    sorted_runs_desc = [r.info.run_id for r in SearchUtils.sort(runs, ["metrics.x desc"])]
    # asc
    assert ["-inf", "-1000", "0", "1000", "inf", "nan", "None"] == sorted_runs_asc
    # desc
    assert ["inf", "1000", "0", "-1000", "-inf", "nan", "None"] == sorted_runs_desc
def _to_sqlalchemy_filtering_statement(sql_statement, session):
    """Translate one parsed filter clause into a SQLAlchemy subquery.

    Attribute clauses return ``None`` (they are applied on ``SqlRun``
    directly); unsupported comparators also return ``None``.
    """
    clause_type = sql_statement.get("type")
    clause_key = sql_statement.get("key")
    clause_value = sql_statement.get("value")
    clause_comparator = sql_statement.get("comparator").upper()

    if SearchUtils.is_metric(clause_type, clause_comparator):
        entity = SqlLatestMetric
        clause_value = float(clause_value)
    elif SearchUtils.is_param(clause_type, clause_comparator):
        entity = SqlParam
    elif SearchUtils.is_tag(clause_type, clause_comparator):
        entity = SqlTag
    elif SearchUtils.is_attribute(clause_type, clause_comparator):
        return None
    else:
        raise MlflowException("Invalid search expression type '%s'" % clause_type,
                              error_code=INVALID_PARAMETER_VALUE)

    if clause_comparator in SearchUtils.CASE_INSENSITIVE_STRING_COMPARISON_OPERATORS:
        # LIKE/ILIKE-style comparators get the column baked into the operator.
        op = SearchUtils.get_sql_filter_ops(entity.value, clause_comparator)
        condition = op(clause_value)
    elif clause_comparator in SearchUtils.filter_ops:
        condition = SearchUtils.filter_ops.get(clause_comparator)(entity.value, clause_value)
    else:
        return None
    return session.query(entity).filter(entity.key == clause_key, condition).subquery()
def _search_runs(self, experiment_ids, filter_string, run_view_type, max_results,
                 order_by, page_token):
    """Search runs by loading all candidates from the database and applying
    filtering, sorting, and pagination in memory via ``SearchUtils``.

    :return: Tuple of (list of Run entities, next-page token or None).
    :raises MlflowException: If ``max_results`` exceeds the allowed threshold.
    """
    # TODO: push search query into backend database layer
    if max_results > SEARCH_MAX_RESULTS_THRESHOLD:
        raise MlflowException(
            "Invalid value for request parameter max_results. It must be at "
            "most {}, but got value {}".format(
                SEARCH_MAX_RESULTS_THRESHOLD, max_results),
            INVALID_PARAMETER_VALUE)
    stages = set(LifecycleStage.view_type_to_stages(run_view_type))
    with self.ManagedSessionMaker() as session:
        # Fetch the appropriate runs and eagerly load their summary metrics, params, and
        # tags. These run attributes are referenced during the invocation of
        # ``run.to_mlflow_entity()``, so eager loading helps avoid additional database queries
        # that are otherwise executed at attribute access time under a lazy loading model.
        queried_runs = session \
            .query(SqlRun) \
            .options(*self._get_eager_run_query_options()) \
            .filter(
                SqlRun.experiment_id.in_(experiment_ids),
                SqlRun.lifecycle_stage.in_(stages)) \
            .all()
        runs = [run.to_mlflow_entity() for run in queried_runs]
        # Filtering/sorting/pagination happen in memory on the entity objects.
        filtered = SearchUtils.filter(runs, filter_string)
        sorted_runs = SearchUtils.sort(filtered, order_by)
        runs, next_page_token = SearchUtils.paginate(sorted_runs, page_token, max_results)
        return runs, next_page_token
def _search_runs(
    self, experiment_ids, filter_string, run_view_type, max_results, order_by, page_token
):
    """Search runs with filtering, ordering, and pagination pushed into SQL.

    :return: Tuple of (list of Run entities, next-page token or None).
    :raises MlflowException: If ``max_results`` exceeds the allowed threshold.
    """

    def compute_next_token(current_size):
        # A full page (current_size == max_results) means more results may
        # exist, so emit a token pointing at the next offset.
        next_token = None
        if max_results == current_size:
            final_offset = offset + max_results
            next_token = SearchUtils.create_page_token(final_offset)

        return next_token

    if max_results > SEARCH_MAX_RESULTS_THRESHOLD:
        raise MlflowException(
            "Invalid value for request parameter max_results. It must be at "
            "most {}, but got value {}".format(SEARCH_MAX_RESULTS_THRESHOLD, max_results),
            INVALID_PARAMETER_VALUE,
        )

    stages = set(LifecycleStage.view_type_to_stages(run_view_type))

    with self.ManagedSessionMaker() as session:
        # Fetch the appropriate runs and eagerly load their summary metrics, params, and
        # tags. These run attributes are referenced during the invocation of
        # ``run.to_mlflow_entity()``, so eager loading helps avoid additional database queries
        # that are otherwise executed at attribute access time under a lazy loading model.
        parsed_filters = SearchUtils.parse_search_filter(filter_string)
        parsed_orderby, sorting_joins = _get_orderby_clauses(order_by, session)

        query = session.query(SqlRun)
        for j in _get_sqlalchemy_filter_clauses(parsed_filters, session):
            query = query.join(j)
        # using an outer join is necessary here because we want to be able to sort
        # on a column (tag, metric or param) without removing the lines that
        # do not have a value for this column (which is what inner join would do)
        for j in sorting_joins:
            query = query.outerjoin(j)

        offset = SearchUtils.parse_start_offset_from_page_token(page_token)
        queried_runs = (
            query.distinct()
            .options(*self._get_eager_run_query_options())
            .filter(
                SqlRun.experiment_id.in_(experiment_ids),
                SqlRun.lifecycle_stage.in_(stages),
                *_get_attributes_filtering_clauses(parsed_filters)
            )
            .order_by(*parsed_orderby)
            .offset(offset)
            .limit(max_results)
            .all()
        )
        runs = [run.to_mlflow_entity() for run in queried_runs]
        next_page_token = compute_next_token(len(runs))

    return runs, next_page_token
def test_bad_comparators(entity_type, bad_comparators, key, entity_value):
    """Each invalid comparator must make SearchUtils.filter raise."""
    run = Run(
        run_info=RunInfo(
            run_uuid="hi",
            run_id="hi",
            experiment_id=0,
            user_id="user-id",
            status=RunStatus.to_string(RunStatus.FAILED),
            start_time=0,
            end_time=1,
            lifecycle_stage=LifecycleStage.ACTIVE),
        run_data=RunData(metrics=[], params=[], tags=[]),
    )
    for comparator in bad_comparators:
        filter_str = "{entity_type}.{key} {comparator} {value}".format(
            entity_type=entity_type, key=key, comparator=comparator,
            value=entity_value)
        with pytest.raises(MlflowException) as exc:
            SearchUtils.filter([run], filter_str)
        assert "Invalid comparator" in str(exc.value.message)
def test_space_order_by_search_runs(order_by, ascending_expected):
    """Order-by identifiers containing spaces parse into type/name/direction."""
    parsed = SearchUtils.parse_order_by_for_search_runs(order_by)
    key_type, key_name, is_ascending = parsed
    assert key_type == "metric"
    assert key_name == "Mean Square Error"
    assert is_ascending == ascending_expected
def search_registered_models(self, filter_string):
    """
    Search for registered models in backend that satisfy the filter criteria.

    :param filter_string: A filter string expression. Currently supports a single filter
                          condition either name of model like ``name = 'model_name'``
    :return: List of :py:class:`mlflow.entities.model_registry.RegisteredModel` objects.
    :raises MlflowException: If the filter uses an unsupported comparator or key,
                             or contains more than one condition.
    """
    parsed_filter = SearchUtils.parse_filter_for_model_registry(filter_string)
    if len(parsed_filter) == 0:
        conditions = []
    elif len(parsed_filter) == 1:
        filter_dict = parsed_filter[0]
        if filter_dict["comparator"] not in \
                SearchUtils.VALID_REGISTERED_MODEL_SEARCH_COMPARATORS:
            # BUGFIX: the original message concatenated to
            # "case-sensitivepartial match" (missing space between adjacent
            # string literals).
            raise MlflowException(
                'Search registered models filter expression only '
                'supports the equality(=) comparator, case-sensitive '
                'partial match (LIKE), and case-insensitive partial '
                'match (ILIKE). Input filter string: %s' % filter_string,
                error_code=INVALID_PARAMETER_VALUE)
        if filter_dict["key"] == "name":
            if filter_dict["comparator"] == "LIKE":
                conditions = [SqlRegisteredModel.name.like(filter_dict["value"])]
            elif filter_dict["comparator"] == "ILIKE":
                conditions = [SqlRegisteredModel.name.ilike(filter_dict["value"])]
            else:
                conditions = [SqlRegisteredModel.name == filter_dict["value"]]
        else:
            raise MlflowException('Invalid filter string: %s' % filter_string,
                                  error_code=INVALID_PARAMETER_VALUE)
    else:
        # BUGFIX: use a separator so the comparators render as
        # "(=) (LIKE) (ILIKE)" instead of "(=)(LIKE)(ILIKE)", and punctuate
        # the two sentences of the error message.
        supported_ops = ' '.join([
            '(' + op + ')'
            for op in SearchUtils.VALID_REGISTERED_MODEL_SEARCH_COMPARATORS
        ])
        sample_query = f'name {supported_ops} "<model_name>"'
        raise MlflowException(
            f'Invalid filter string: {filter_string}. '
            'Search registered models supports filter expressions like: ' + sample_query,
            error_code=INVALID_PARAMETER_VALUE)
    with self.ManagedSessionMaker() as session:
        if self.db_type == SQLITE:
            # SQLite LIKE is case-insensitive by default; this PRAGMA makes
            # LIKE case-sensitive so LIKE and ILIKE behave distinctly, as on
            # other backends.
            session.execute("PRAGMA case_sensitive_like = true;")
        sql_registered_models = session.query(SqlRegisteredModel).filter(
            *conditions).all()
        registered_models = [
            rm.to_mlflow_entity() for rm in sql_registered_models
        ]
        return registered_models
def _get_orderby_clauses(self, order_by_list: List[str]) -> List[dict]:
    """Translate order-by strings into Elasticsearch sort clauses.

    The natural ordering (start_time desc, run_id asc) is always appended
    for tie-breaking.
    """
    field_by_type = {
        "metric": "latest_metrics",
        "parameter": "params",
        "tag": "tags"
    }
    sort_clauses = []
    for order_by_clause in order_by_list or []:
        key_type, key, ascending = SearchUtils.parse_order_by_for_search_runs(
            order_by_clause)
        sort_order = "asc" if ascending else "desc"
        if SearchUtils.is_attribute(key_type, "="):
            sort_clauses.append({key: {'order': sort_order}})
        else:
            # Metrics/params/tags live in nested documents; filter the nested
            # path down to the requested key before sorting on its value.
            nested_path = field_by_type[key_type]
            sort_clauses.append({
                f'{nested_path}.value': {
                    'order': sort_order,
                    "nested": {
                        "path": nested_path,
                        "filter": {
                            "term": {
                                f'{nested_path}.key': key
                            }
                        }
                    }
                }
            })
    sort_clauses.append({"start_time": {'order': "desc"}})
    sort_clauses.append({"run_id": {'order': "asc"}})
    return sort_clauses
def test_pagination(page_token, max_results, matching_runs, expected_next_page_token):
    """Paginate three identical-shaped runs and check indices and the token."""
    runs = [
        Run(run_info=RunInfo(
                run_uuid=str(idx), run_id=str(idx), experiment_id=0,
                user_id="user-id",
                status=RunStatus.to_string(RunStatus.FAILED),
                start_time=0, end_time=1,
                lifecycle_stage=LifecycleStage.ACTIVE),
            run_data=RunData([], [], []))
        for idx in range(3)
    ]
    encoded_page_token = None
    if page_token:
        encoded_page_token = base64.b64encode(json.dumps(page_token).encode("utf-8"))
    paginated_runs, next_page_token = SearchUtils.paginate(runs, encoded_page_token, max_results)

    paginated_run_indices = [runs.index(run) for run in paginated_runs]
    assert paginated_run_indices == matching_runs

    decoded_next_page_token = None
    if next_page_token:
        decoded_next_page_token = json.loads(base64.b64decode(next_page_token))
    assert decoded_next_page_token == expected_next_page_token
def test_correct_sorting(order_bys, matching_runs):
    """Sort three runs and verify the resulting order by original index."""
    # (run_id, experiment_id, status, start_time, metric_val, param_val, tags)
    run_specs = [
        ("9", 0, RunStatus.FAILED, 0, 121, "A", []),
        ("8", 0, RunStatus.FINISHED, 1, 123, "A", [RunTag("tag1", "C")]),
        ("7", 1, RunStatus.FAILED, 1, 125, "B", [RunTag("tag1", "D")]),
    ]
    runs = [
        Run(run_info=RunInfo(
                run_uuid=rid, run_id=rid, experiment_id=exp_id,
                user_id="user-id",
                status=RunStatus.to_string(status),
                start_time=start, end_time=1,
                lifecycle_stage=LifecycleStage.ACTIVE),
            run_data=RunData(
                metrics=[Metric("key1", metric_val, 1, 0)],
                params=[Param("my_param", param_val)],
                tags=tags))
        for rid, exp_id, status, start, metric_val, param_val, tags in run_specs
    ]
    sorted_runs = SearchUtils.sort(runs, order_bys)
    sorted_run_indices = [runs.index(run) for run in sorted_runs]
    assert sorted_run_indices == matching_runs
def compute_next_token(current_size):
    # A full page means more results may exist beyond it; emit a token at
    # the next offset. A short page ends the listing (no token).
    if max_results == current_size:
        return SearchUtils.create_page_token(offset + max_results)
    return None
def test_correct_filtering(filter_string, matching_runs):
    """Filter three runs and verify exactly the expected subset survives."""
    # (run_id, experiment_id, status, metric_val, param_val, tags)
    run_specs = [
        ("hi", 0, RunStatus.FAILED, 121, "A", []),
        ("hi2", 0, RunStatus.FINISHED, 123, "A", [RunTag("tag1", "C")]),
        ("hi3", 1, RunStatus.FAILED, 125, "B", [RunTag("tag1", "D")]),
    ]
    runs = [
        Run(run_info=RunInfo(
                run_uuid=rid, run_id=rid, experiment_id=exp_id,
                user_id="user-id",
                status=RunStatus.to_string(status),
                start_time=0, end_time=1,
                lifecycle_stage=LifecycleStage.ACTIVE),
            run_data=RunData(
                metrics=[Metric("key1", metric_val, 1, 0)],
                params=[Param("my_param", param_val)],
                tags=tags))
        for rid, exp_id, status, metric_val, param_val, tags in run_specs
    ]
    filtered_runs = SearchUtils.filter(runs, filter_string)
    assert set(filtered_runs) == {runs[i] for i in matching_runs}
def _parse_search_registered_models_order_by(cls, order_by_list):
    """Sorts a set of registered models based on their natural ordering and an
    overriding set of order_bys. Registered models are naturally ordered first
    by name ascending.
    """
    clauses = []
    for order_by_clause in order_by_list or []:
        attribute_token, ascending = \
            SearchUtils.parse_order_by_for_search_registered_models(order_by_clause)
        if attribute_token == SqlRegisteredModel.name.key:
            field = SqlRegisteredModel.name
        elif attribute_token in SearchUtils.VALID_TIMESTAMP_ORDER_BY_KEYS:
            # All timestamp-style keys sort on the last-updated column.
            field = SqlRegisteredModel.last_updated_time
        else:
            raise MlflowException(
                f"Invalid order by key '{attribute_token}' specified."
                f"Valid keys are "
                f"'{SearchUtils.RECOMMENDED_ORDER_BY_KEYS_REGISTERED_MODELS}'",
                error_code=INVALID_PARAMETER_VALUE)
        clauses.append(field.asc() if ascending else field.desc())
    # Natural tie-break: always end with name ascending.
    clauses.append(SqlRegisteredModel.name.asc())
    return clauses
def _search_runs(self, experiment_ids, filter_string, run_view_type, max_results,
                 order_by, page_token):
    """Load all candidate runs, then filter/sort/paginate in memory."""
    if max_results > SEARCH_MAX_RESULTS_THRESHOLD:
        raise MlflowException(
            "Invalid value for request parameter max_results. It must be at "
            "most {}, but got value {}".format(
                SEARCH_MAX_RESULTS_THRESHOLD, max_results),
            databricks_pb2.INVALID_PARAMETER_VALUE)
    candidates = [
        self.get_run(info.run_id)
        for experiment_id in experiment_ids
        for info in self._list_run_infos(experiment_id, run_view_type)
    ]
    matching = SearchUtils.filter(candidates, filter_string)
    ordered = SearchUtils.sort(matching, order_by)
    return SearchUtils.paginate(ordered, page_token, max_results)
def _get_attributes_filtering_clauses(parsed):
    """Build SQLAlchemy filter clauses for the attribute-typed search terms."""
    clauses = []
    for stmt in parsed:
        comparator = stmt.get("comparator").upper()
        if not SearchUtils.is_attribute(stmt.get("type"), comparator):
            continue
        # key is guaranteed to be a valid searchable attribute of entities.RunInfo
        # by the call to parse_search_filter
        column = getattr(SqlRun, SqlRun.get_attribute_name(stmt.get("key")))
        value = stmt.get("value")
        if comparator in SearchUtils.CASE_INSENSITIVE_STRING_COMPARISON_OPERATORS:
            # LIKE/ILIKE operators are bound to the column up front.
            clauses.append(SearchUtils.get_sql_filter_ops(column, comparator)(value))
        elif comparator in SearchUtils.filter_ops:
            clauses.append(SearchUtils.filter_ops.get(comparator)(column, value))
    return clauses
def _search_runs(self, experiment_ids, filter_string, run_view_type, max_results,
                 order_by, page_token):
    """Search runs in memory; this store variant does not support page tokens."""
    if page_token:
        raise MlflowException("SQLAlchemy-backed tracking stores do not yet support pagination"
                              "tokens.")
    # TODO: push search query into backend database layer
    if max_results > SEARCH_MAX_RESULTS_THRESHOLD:
        raise MlflowException("Invalid value for request parameter max_results. It must be at "
                              "most {}, but got value {}".format(SEARCH_MAX_RESULTS_THRESHOLD,
                                                                 max_results),
                              INVALID_PARAMETER_VALUE)
    with self.ManagedSessionMaker() as session:
        all_runs = []
        for exp in experiment_ids:
            for sql_run in self._list_runs(session, exp, run_view_type):
                all_runs.append(sql_run.to_mlflow_entity())
        matching = SearchUtils.filter(all_runs, filter_string)
        # Truncate to max_results; no next-page token is ever produced.
        return SearchUtils.sort(matching, order_by)[:max_results], None
def _list_experiments(
    self,
    ids=None,
    names=None,
    view_type=ViewType.ACTIVE_ONLY,
    max_results=None,
    page_token=None,
    eager=False,
):
    """
    List experiments matching the given ids/names and lifecycle view type,
    optionally paginated.

    :param ids: Optional list of experiment ids to restrict the query to.
    :param names: Optional list of experiment names to restrict the query to.
    :param view_type: Lifecycle view (active / deleted / all) to include.
    :param max_results: If set, page size; a pagination token is produced
                        when more results remain.
    :param page_token: Token from a previous call, encoding the offset.
    :param eager: If ``True``, eagerly loads each experiments's tags. If ``False``, these tags
                  are not eagerly loaded and will be loaded if/when their corresponding
                  object properties are accessed from a resulting ``SqlExperiment`` object.
    :return: ``PagedList`` of experiment entities.
    """
    stages = LifecycleStage.view_type_to_stages(view_type)
    conditions = [SqlExperiment.lifecycle_stage.in_(stages)]
    if ids and len(ids) > 0:
        int_ids = [int(eid) for eid in ids]
        conditions.append(SqlExperiment.experiment_id.in_(int_ids))
    if names and len(names) > 0:
        conditions.append(SqlExperiment.name.in_(names))

    max_results_for_query = None
    if max_results is not None:
        # Fetch one extra row so we can tell whether another page exists.
        max_results_for_query = max_results + 1

        def compute_next_token(current_size):
            # A token is only produced when the query returned the extra row,
            # i.e. more experiments remain beyond this page.
            next_token = None
            if max_results_for_query == current_size:
                final_offset = offset + max_results
                next_token = SearchUtils.create_page_token(final_offset)

            return next_token

    with self.ManagedSessionMaker() as session:
        query_options = self._get_eager_experiment_query_options() if eager else []
        if max_results is not None:
            offset = SearchUtils.parse_start_offset_from_page_token(page_token)
            queried_experiments = (
                session.query(SqlExperiment).options(*query_options)
                .order_by(SqlExperiment.experiment_id)
                .filter(*conditions)
                .offset(offset)
                .limit(max_results_for_query)
                .all()
            )
        else:
            queried_experiments = (
                session.query(SqlExperiment).options(*query_options)
                .filter(*conditions)
                .all()
            )
        experiments = [exp.to_mlflow_entity() for exp in queried_experiments]
    if max_results is not None:
        # Drop the extra probe row before returning the page.
        return PagedList(experiments[:max_results], compute_next_token(len(experiments)))
    else:
        return PagedList(experiments, None)
def test_filter_runs_by_start_time():
    """Runs filter correctly on the start_time attribute."""
    runs = []
    for start, rid in enumerate(["a", "b", "c"]):
        runs.append(
            Run(
                run_info=RunInfo(
                    run_uuid=rid,
                    run_id=rid,
                    experiment_id=0,
                    user_id="user-id",
                    status=RunStatus.to_string(RunStatus.FINISHED),
                    start_time=start,
                    end_time=1,
                    lifecycle_stage=LifecycleStage.ACTIVE,
                ),
                run_data=RunData(),
            )
        )
    assert SearchUtils.filter(runs, "attribute.start_time >= 0") == runs
    assert SearchUtils.filter(runs, "attribute.start_time > 1") == runs[2:]
    assert SearchUtils.filter(runs, "attribute.start_time = 2") == runs[2:]
def search_runs(self, experiment_ids, filter_string, run_view_type,
                max_results=SEARCH_MAX_RESULTS_THRESHOLD, order_by=None):
    """Search runs in memory; returns at most ``max_results`` sorted runs."""
    # TODO: push search query into backend database layer
    if max_results > SEARCH_MAX_RESULTS_THRESHOLD:
        raise MlflowException(
            "Invalid value for request parameter max_results. It must be at "
            "most {}, but got value {}".format(
                SEARCH_MAX_RESULTS_THRESHOLD, max_results),
            INVALID_PARAMETER_VALUE)
    with self.ManagedSessionMaker() as session:
        candidates = []
        for exp in experiment_ids:
            candidates.extend(
                run.to_mlflow_entity()
                for run in self._list_runs(session, exp, run_view_type))
        matching = SearchUtils.filter(candidates, filter_string)
        return SearchUtils.sort(matching, order_by)[:max_results]
def _get_orderby_clauses(order_by_list, session):
    """Sorts a set of runs based on their natural ordering and an overriding set of order_bys.
    Runs are naturally ordered first by start time descending, then by run id for tie-breaking.

    :param order_by_list: list of order-by clause strings to apply before the
                          natural ordering.
    :param session: SQLAlchemy session used to build per-clause subqueries.
    :return: tuple (ordering clauses, subqueries that must be outer-joined
             into the main query).
    :raises MlflowException: on an order-by identifier of unknown type.
    """
    clauses = []
    ordering_joins = []
    clause_id = 0
    # contrary to filters, it is not easily feasible to separately handle sorting
    # on attributes and on joined tables as we must keep all clauses in the same order
    if order_by_list:
        for order_by_clause in order_by_list:
            clause_id += 1
            (key_type, key, ascending) = SearchUtils.parse_order_by_for_search_runs(
                order_by_clause)
            if SearchUtils.is_attribute(key_type, "="):
                # Attributes sort directly on a SqlRun column.
                order_value = getattr(SqlRun, SqlRun.get_attribute_name(key))
            else:
                if SearchUtils.is_metric(key_type, "="):  # any valid comparator
                    entity = SqlLatestMetric
                elif SearchUtils.is_tag(key_type, "="):
                    entity = SqlTag
                elif SearchUtils.is_param(key_type, "="):
                    entity = SqlParam
                else:
                    raise MlflowException(
                        "Invalid identifier type '%s'" % key_type,
                        error_code=INVALID_PARAMETER_VALUE,
                    )

                # build a subquery first because we will join it in the main request so that the
                # metric we want to sort on is available when we apply the sorting clause
                subquery = session.query(entity).filter(entity.key == key).subquery()

                ordering_joins.append(subquery)
                order_value = subquery.c.value

            # sqlite does not support NULLS LAST expression, so we sort first by
            # presence of the field (and is_nan for metrics), then by actual value
            # As the subqueries are created independently and used later in the
            # same main query, the CASE WHEN columns need to have unique names to
            # avoid ambiguity
            if SearchUtils.is_metric(key_type, "="):
                clauses.append(
                    sql.case(
                        [(subquery.c.is_nan.is_(True), 1), (order_value.is_(None), 1)], else_=0
                    ).label("clause_%s" % clause_id)
                )
            else:  # other entities do not have an 'is_nan' field
                clauses.append(
                    sql.case([(order_value.is_(None), 1)], else_=0).label("clause_%s" % clause_id)
                )

            if ascending:
                clauses.append(order_value)
            else:
                clauses.append(order_value.desc())

    # Natural ordering, appended last so explicit order_bys take precedence.
    clauses.append(SqlRun.start_time.desc())
    clauses.append(SqlRun.run_uuid)
    return clauses, ordering_joins
def search_model_versions(self, filter_string):
    """
    Search for model versions in backend that satisfy the filter criteria.

    :param filter_string: A filter string expression. Currently supports a single filter
                          condition either name of model like ``name = 'model_name'`` or
                          ``run_id = '...'``.
    :return: PagedList of :py:class:`mlflow.entities.model_registry.ModelVersion` objects.
    :raises MlflowException: If the filter uses a comparator other than ``=``,
                             an unsupported key, or more than one condition.
    """
    parsed_filter = SearchUtils.parse_filter_for_model_versions(filter_string)
    if len(parsed_filter) == 0:
        conditions = []
    elif len(parsed_filter) == 1:
        filter_dict = parsed_filter[0]
        # Only the equality comparator is supported for model-version search.
        if filter_dict["comparator"] != "=":
            raise MlflowException(
                "Model Registry search filter only supports equality(=) "
                "comparator. Input filter string: %s" % filter_string,
                error_code=INVALID_PARAMETER_VALUE,
            )
        if filter_dict["key"] == "name":
            conditions = [SqlModelVersion.name == filter_dict["value"]]
        elif filter_dict["key"] == "source_path":
            conditions = [SqlModelVersion.source == filter_dict["value"]]
        elif filter_dict["key"] == "run_id":
            conditions = [SqlModelVersion.run_id == filter_dict["value"]]
        else:
            raise MlflowException("Invalid filter string: %s" % filter_string,
                                  error_code=INVALID_PARAMETER_VALUE)
    else:
        # Multiple filter conditions are not supported.
        raise MlflowException(
            "Model Registry expects filter to be one of "
            "\"name = '<model_name>'\" or "
            "\"source_path = '<source_path>'\" or \"run_id = '<run_id>'."
            "Input filter string: %s. " % filter_string,
            error_code=INVALID_PARAMETER_VALUE,
        )
    with self.ManagedSessionMaker() as session:
        # Always exclude versions in the internal DELETED stage.
        conditions.append(SqlModelVersion.current_stage != STAGE_DELETED_INTERNAL)
        sql_model_version = session.query(SqlModelVersion).filter(*conditions).all()
        model_versions = [mv.to_mlflow_entity() for mv in sql_model_version]
        return PagedList(model_versions, None)
def list_experiments(
    self,
    view_type=ViewType.ACTIVE_ONLY,
    max_results=None,
    page_token=None,
):
    """
    :param view_type: Qualify requested type of experiments.
    :param max_results: If passed, specifies the maximum number of experiments desired. If not
                        passed, all experiments will be returned.
    :param page_token: Token specifying the next page of results. It should be obtained from
                       a ``list_experiments`` call.
    :return: A :py:class:`PagedList <mlflow.store.entities.PagedList>` of
             :py:class:`Experiment <mlflow.entities.Experiment>` objects. The pagination
             token for the next page can be obtained via the ``token`` attribute of the
             object.
    """
    # Imported locally, presumably to avoid import cycles at module load —
    # TODO confirm before hoisting to the top of the file.
    from mlflow.utils.search_utils import SearchUtils
    from mlflow.store.entities.paged_list import PagedList

    _validate_list_experiments_max_results(max_results)
    self._check_root_dir()
    rsl = []
    if view_type == ViewType.ACTIVE_ONLY or view_type == ViewType.ALL:
        rsl += self._get_active_experiments(full_path=False)
    if view_type == ViewType.DELETED_ONLY or view_type == ViewType.ALL:
        rsl += self._get_deleted_experiments(full_path=False)
    experiments = []
    for exp_id in rsl:
        try:
            # trap and warn known issues, will raise unexpected exceptions to caller
            experiment = self._get_experiment(exp_id, view_type)
            if experiment:
                experiments.append(experiment)
        except MissingConfigException as rnfe:
            # Trap malformed experiments and log warnings.
            logging.warning(
                "Malformed experiment '%s'. Detailed error %s",
                str(exp_id),
                str(rnfe),
                exc_info=True,
            )
    if max_results is not None:
        # Paginate in memory over the already-materialized list.
        experiments, next_page_token = SearchUtils.paginate(
            experiments, page_token, max_results
        )
        return PagedList(experiments, next_page_token)
    else:
        return PagedList(experiments, None)
def _search_runs(
    self,
    experiment_ids,
    filter_string,
    run_view_type,
    max_results,
    order_by,
    page_token,
):
    """Assemble full Run entities for each experiment, then filter, sort and
    paginate them in memory."""
    if max_results > SEARCH_MAX_RESULTS_THRESHOLD:
        raise MlflowException(
            "Invalid value for request parameter max_results. It must be at "
            "most {}, but got value {}".format(
                SEARCH_MAX_RESULTS_THRESHOLD, max_results),
            INVALID_PARAMETER_VALUE,
        )
    candidates = []
    for experiment_id in experiment_ids:
        run_ids = self._list_runs_ids(experiment_id, run_view_type)
        for raw in self._get_run_list(run_ids):
            info = _dict_to_run_info(raw)
            # Attach the run's metrics, params and tags to build a full entity.
            rid = info.run_id
            data = RunData(
                self.get_all_metrics(rid),
                self.get_all_params(rid),
                self.get_all_tags(rid),
            )
            candidates.append(Run(info, data))
    matching = SearchUtils.filter(candidates, filter_string)
    ordered = SearchUtils.sort(matching, order_by)
    return SearchUtils.paginate(ordered, page_token, max_results)
def _get_attributes_filtering_clauses(parsed):
    """Build SQLAlchemy filter clauses for attribute-typed search statements."""
    clauses = []
    for stmt in parsed:
        if not SearchUtils.is_attribute(stmt.get('type'), stmt.get('comparator')):
            continue
        # validity of the comparator is checked in SearchUtils.parse_search_filter()
        op = SearchUtils.filter_ops.get(stmt.get('comparator'))
        if op:
            # key_name is guaranteed to be a valid searchable attribute of entities.RunInfo
            # by the call to parse_search_filter
            column = getattr(SqlRun, SqlRun.get_attribute_name(stmt.get('key')))
            clauses.append(op(column, stmt.get('value')))
    return clauses
def _parse_search_registered_models_order_by(cls, order_by_list):
    """Sorts a set of registered models based on their natural ordering
    and an overriding set of order_bys. Registered models are naturally ordered
    first by name ascending.

    :param order_by_list: list of order-by clause strings.
    :return: list of SQLAlchemy ordering clauses.
    :raises MlflowException: on an unknown order-by key or when the same
            field appears more than once in ``order_by_list``.
    """
    clauses = []
    observed_order_by_clauses = set()
    if order_by_list:
        for order_by_clause in order_by_list:
            (
                attribute_token,
                ascending,
            ) = SearchUtils.parse_order_by_for_search_registered_models(
                order_by_clause)
            if attribute_token == SqlRegisteredModel.name.key:
                field = SqlRegisteredModel.name
            elif attribute_token in SearchUtils.VALID_TIMESTAMP_ORDER_BY_KEYS:
                # All timestamp-style keys sort on the last_updated_time column.
                field = SqlRegisteredModel.last_updated_time
            else:
                raise MlflowException(
                    "Invalid order by key '{}' specified.".format(
                        attribute_token) + "Valid keys are " + "'{}'".format(
                            SearchUtils.
                            RECOMMENDED_ORDER_BY_KEYS_REGISTERED_MODELS),
                    error_code=INVALID_PARAMETER_VALUE,
                )
            # Reject the same field appearing twice in order_by.
            if field.key in observed_order_by_clauses:
                raise MlflowException(
                    "`order_by` contains duplicate fields: {}".format(
                        order_by_list))
            observed_order_by_clauses.add(field.key)
            if ascending:
                clauses.append(field.asc())
            else:
                clauses.append(field.desc())
    # Tie-break by name ascending unless the caller already ordered by name.
    if SqlRegisteredModel.name.key not in observed_order_by_clauses:
        clauses.append(SqlRegisteredModel.name.asc())
    return clauses
def _search_runs(
        self,
        experiment_ids: List[str],
        filter_string: str,
        run_view_type: str,
        max_results: int = SEARCH_MAX_RESULTS_DEFAULT,
        order_by: List[str] = None,
        page_token: str = None,
        columns_to_whitelist: List[str] = None) -> Tuple[List[Run], str]:
    """Search runs stored in the ``mlflow-runs`` Elasticsearch index.

    :return: tuple of (list of Run entities, stringified next-page token —
             the ``search_after`` sort values of the last hit, or ``[]``).
    :raises MlflowException: If ``max_results`` exceeds 10000.
    """
    if max_results > 10000:
        raise MlflowException(
            "Invalid value for request parameter max_results. It must be at "
            "most {}, but got value {}".format(10000, max_results),
            INVALID_PARAMETER_VALUE)
    stages = LifecycleStage.view_type_to_stages(run_view_type)
    parsed_filters = SearchUtils.parse_search_filter(filter_string)
    # NOTE(review): only experiment_ids[0] is queried — confirm callers never
    # pass more than one experiment id here.
    filter_queries = [
        Q("match", experiment_id=experiment_ids[0]),
        Q("terms", lifecycle_stage=stages)
    ]
    filter_queries += self._build_elasticsearch_query(parsed_filters)
    sort_clauses = self._get_orderby_clauses(order_by)
    s = Search(index="mlflow-runs").query('bool', filter=filter_queries)
    s = s.sort(*sort_clauses)
    if page_token != "" and page_token is not None:
        # The page token is the repr of the previous page's last sort values,
        # consumed here as Elasticsearch "search_after".
        s = s.extra(search_after=ast.literal_eval(page_token))
    response = s.params(size=max_results).execute()
    columns_to_whitelist_key_dict = self._build_columns_to_whitelist_key_dict(
        columns_to_whitelist)
    runs = [
        self._hit_to_mlflow_run(hit, columns_to_whitelist_key_dict)
        for hit in response
    ]
    if len(runs) == max_results:
        # A full page: expose the last hit's sort values as the next token.
        next_page_token = response.hits.hits[-1].sort
    else:
        next_page_token = []
    return runs, str(next_page_token)
def test_invalid_page_tokens(page_token, error_message):
    """A malformed page token makes paginate raise with the expected message."""
    with pytest.raises(MlflowException) as exc:
        SearchUtils.paginate([], page_token, 1)
    assert error_message in exc.value.message
def test_invalid_order_by(order_by, error_message):
    """A malformed order-by clause makes the parser raise with the expected message."""
    with pytest.raises(MlflowException) as exc:
        SearchUtils._parse_order_by(order_by)
    assert error_message in exc.value.message
def test_invalid_clauses(filter_string, error_message):
    """A malformed filter string makes the parser raise with the expected message."""
    with pytest.raises(MlflowException) as exc:
        SearchUtils.parse_search_filter(filter_string)
    assert error_message in exc.value.message