示例#1
0
def _to_sqlalchemy_filtering_statement(sql_statement, session):
    """Build the subquery applying one parsed filter clause to its key/value table.

    :param sql_statement: Mapping holding the ``type``, ``key``, ``value`` and
                          ``comparator`` entries of a single parsed clause.
    :param session: Session used to build the query.
    :return: A subquery over the matching rows, or ``None`` for attribute clauses and
             for comparators without an entry in ``SearchUtils.filter_ops``.
    :raises MlflowException: If the clause type is not metric, param, tag or attribute.
    """
    identifier_type = sql_statement.get('type')
    identifier_key = sql_statement.get('key')
    operand = sql_statement.get('value')
    comparator = sql_statement.get('comparator')

    if SearchUtils.is_metric(identifier_type, comparator):
        table = SqlLatestMetric
        # Metric values are compared numerically.
        operand = float(operand)
    elif SearchUtils.is_param(identifier_type, comparator):
        table = SqlParam
    elif SearchUtils.is_tag(identifier_type, comparator):
        table = SqlTag
    elif SearchUtils.is_attribute(identifier_type, comparator):
        # Attribute clauses do not produce a subquery here.
        return None
    else:
        raise MlflowException("Invalid search expression type '%s'" % identifier_type,
                              error_code=INVALID_PARAMETER_VALUE)

    # validity of the comparator is checked in SearchUtils.parse_search_filter()
    compare = SearchUtils.filter_ops.get(comparator)
    if not compare:
        return None

    matching_rows = session.query(table).filter(
        table.key == identifier_key, compare(table.value, operand))
    return matching_rows.subquery()
示例#2
0
def test_order_by_metric_with_nans_infs_nones():
    """Check ``SearchUtils.sort`` on a metric whose values include NaN, infinities and ``None``.

    The expected lists put NaN and the missing value last for both sort directions.
    """
    raw_values = ["nan", "inf", "-inf", "-1000", "0", "1000", "None"]

    def build_run(raw):
        # The raw string doubles as the run id so results can be compared by id.
        info = RunInfo(
            run_id=raw,
            run_uuid=raw,
            experiment_id=0,
            user_id="user",
            status=RunStatus.to_string(RunStatus.FINISHED),
            start_time=0,
            end_time=1,
            lifecycle_stage=LifecycleStage.ACTIVE,
        )
        metric_value = None if raw == "None" else float(raw)
        data = RunData(metrics=[Metric("x", metric_value, 1, 0)])
        return Run(run_info=info, run_data=data)

    runs = [build_run(raw) for raw in raw_values]

    ascending_ids = [
        run.info.run_id for run in SearchUtils.sort(runs, ["metrics.x asc"])
    ]
    descending_ids = [
        run.info.run_id for run in SearchUtils.sort(runs, ["metrics.x desc"])
    ]

    # asc
    assert ["-inf", "-1000", "0", "1000", "inf", "nan", "None"] == ascending_ids
    # desc
    assert ["inf", "1000", "0", "-1000", "-inf", "nan", "None"] == descending_ids
def _to_sqlalchemy_filtering_statement(sql_statement, session):
    """Build the subquery applying one parsed filter clause to its key/value table.

    The comparator is upper-cased before it is matched against the operator tables.

    :param sql_statement: Mapping holding the ``type``, ``key``, ``value`` and
                          ``comparator`` entries of a single parsed clause.
    :param session: Session used to build the query.
    :return: A subquery over the matching rows, or ``None`` for attribute clauses and
             for comparators found in neither operator table.
    :raises MlflowException: If the clause type is not metric, param, tag or attribute.
    """
    identifier_type = sql_statement.get("type")
    identifier_key = sql_statement.get("key")
    operand = sql_statement.get("value")
    comparator = sql_statement.get("comparator").upper()

    if SearchUtils.is_metric(identifier_type, comparator):
        table = SqlLatestMetric
        # Metric values are compared numerically.
        operand = float(operand)
    elif SearchUtils.is_param(identifier_type, comparator):
        table = SqlParam
    elif SearchUtils.is_tag(identifier_type, comparator):
        table = SqlTag
    elif SearchUtils.is_attribute(identifier_type, comparator):
        # Attribute clauses do not produce a subquery here.
        return None
    else:
        raise MlflowException("Invalid search expression type '%s'" % identifier_type,
                              error_code=INVALID_PARAMETER_VALUE)

    if comparator in SearchUtils.CASE_INSENSITIVE_STRING_COMPARISON_OPERATORS:
        # These operators come already bound to the column and take only the value.
        column_op = SearchUtils.get_sql_filter_ops(table.value, comparator)
        matching_rows = session.query(table).filter(
            table.key == identifier_key, column_op(operand))
        return matching_rows.subquery()

    if comparator in SearchUtils.filter_ops:
        binary_op = SearchUtils.filter_ops.get(comparator)
        matching_rows = session.query(table).filter(
            table.key == identifier_key, binary_op(table.value, operand))
        return matching_rows.subquery()

    return None
示例#4
0
    def _search_runs(self, experiment_ids, filter_string, run_view_type,
                     max_results, order_by, page_token):
        """Load the candidate runs from the database, then filter, sort and page them in memory.

        :param experiment_ids: Experiments whose runs are searched.
        :param filter_string: Filter expression handed to ``SearchUtils.filter``.
        :param run_view_type: View type mapped to the lifecycle stages to include.
        :param max_results: Page size; must not exceed ``SEARCH_MAX_RESULTS_THRESHOLD``.
        :param order_by: Ordering clauses handed to ``SearchUtils.sort``.
        :param page_token: Token handed to ``SearchUtils.paginate``.
        :return: Tuple of (runs on the requested page, next page token).
        :raises MlflowException: If ``max_results`` exceeds the threshold.
        """
        # TODO: push search query into backend database layer
        if max_results > SEARCH_MAX_RESULTS_THRESHOLD:
            message = (
                "Invalid value for request parameter max_results. It must be at "
                "most {}, but got value {}"
            ).format(SEARCH_MAX_RESULTS_THRESHOLD, max_results)
            raise MlflowException(message, INVALID_PARAMETER_VALUE)

        allowed_stages = set(LifecycleStage.view_type_to_stages(run_view_type))
        with self.ManagedSessionMaker() as session:
            # Eagerly load summary metrics, params and tags: ``run.to_mlflow_entity()``
            # reads them, and lazy loading would issue extra queries at attribute
            # access time.
            run_query = session.query(SqlRun)
            run_query = run_query.options(*self._get_eager_run_query_options())
            run_query = run_query.filter(
                SqlRun.experiment_id.in_(experiment_ids),
                SqlRun.lifecycle_stage.in_(allowed_stages))
            runs = [sql_run.to_mlflow_entity() for sql_run in run_query.all()]

        matching = SearchUtils.filter(runs, filter_string)
        ordered = SearchUtils.sort(matching, order_by)
        page, next_page_token = SearchUtils.paginate(ordered, page_token, max_results)
        return page, next_page_token
示例#5
0
    def _search_runs(
        self, experiment_ids, filter_string, run_view_type, max_results, order_by, page_token
    ):
        """Search runs with filtering, ordering and paging performed by the database query.

        :param experiment_ids: Experiments whose runs are searched.
        :param filter_string: Filter expression parsed by ``SearchUtils.parse_search_filter``.
        :param run_view_type: View type mapped to the lifecycle stages to include.
        :param max_results: Page size; must not exceed ``SEARCH_MAX_RESULTS_THRESHOLD``.
        :param order_by: Ordering clauses translated by ``_get_orderby_clauses``.
        :param page_token: Token from which the start offset is decoded.
        :return: Tuple of (runs on the requested page, next page token or ``None``).
        :raises MlflowException: If ``max_results`` exceeds the threshold.
        """

        def compute_next_token(current_size):
            # A token is issued whenever the page came back full, including the case
            # where the full page happens to be the last one.
            if max_results == current_size:
                return SearchUtils.create_page_token(offset + max_results)
            return None

        if max_results > SEARCH_MAX_RESULTS_THRESHOLD:
            message = (
                "Invalid value for request parameter max_results. It must be at "
                "most {}, but got value {}"
            ).format(SEARCH_MAX_RESULTS_THRESHOLD, max_results)
            raise MlflowException(message, INVALID_PARAMETER_VALUE)

        stages = set(LifecycleStage.view_type_to_stages(run_view_type))

        with self.ManagedSessionMaker() as session:
            parsed_filters = SearchUtils.parse_search_filter(filter_string)
            parsed_orderby, sorting_joins = _get_orderby_clauses(order_by, session)

            query = session.query(SqlRun)
            # Filter join targets are inner-joined.
            for filter_join in _get_sqlalchemy_filter_clauses(parsed_filters, session):
                query = query.join(filter_join)
            # using an outer join is necessary here because we want to be able to sort
            # on a column (tag, metric or param) without removing the lines that
            # do not have a value for this column (which is what inner join would do)
            for sorting_join in sorting_joins:
                query = query.outerjoin(sorting_join)

            offset = SearchUtils.parse_start_offset_from_page_token(page_token)

            # Eagerly load summary metrics, params and tags: ``run.to_mlflow_entity()``
            # reads them, and lazy loading would issue extra queries at attribute
            # access time.
            query = query.distinct().options(*self._get_eager_run_query_options())
            query = query.filter(
                SqlRun.experiment_id.in_(experiment_ids),
                SqlRun.lifecycle_stage.in_(stages),
                *_get_attributes_filtering_clauses(parsed_filters)
            )
            page_rows = query.order_by(*parsed_orderby).offset(offset).limit(max_results).all()

            runs = [sql_run.to_mlflow_entity() for sql_run in page_rows]
            next_page_token = compute_next_token(len(runs))

        return runs, next_page_token
示例#6
0
def test_bad_comparators(entity_type, bad_comparators, key, entity_value):
    """Each comparator in ``bad_comparators`` must be rejected as an invalid comparator."""
    info = RunInfo(
        run_uuid="hi", run_id="hi", experiment_id=0,
        user_id="user-id", status=RunStatus.to_string(RunStatus.FAILED),
        start_time=0, end_time=1, lifecycle_stage=LifecycleStage.ACTIVE)
    run = Run(run_info=info, run_data=RunData(metrics=[], params=[], tags=[]))

    filter_template = "{entity_type}.{key} {comparator} {value}"
    for comparator in bad_comparators:
        bad_filter = filter_template.format(
            entity_type=entity_type, key=key, comparator=comparator, value=entity_value)
        with pytest.raises(MlflowException) as e:
            SearchUtils.filter([run], bad_filter)
        assert "Invalid comparator" in str(e.value.message)
def test_space_order_by_search_runs(order_by, ascending_expected):
    """An order-by clause on a metric whose name contains spaces parses into its three parts."""
    parsed = SearchUtils.parse_order_by_for_search_runs(order_by)
    identifier_type, identifier_name, ascending = parsed

    assert identifier_type == "metric"
    assert identifier_name == "Mean Square Error"
    assert ascending == ascending_expected
示例#8
0
    def search_registered_models(self, filter_string):
        """
        Search for registered models in backend that satisfy the filter criteria.

        :param filter_string: A filter string expression. At most one condition is supported
                              and it must be on the model name, e.g. ``name = 'model_name'``.
                              The comparator must be one of
                              ``SearchUtils.VALID_REGISTERED_MODEL_SEARCH_COMPARATORS``;
                              ``LIKE`` and ``ILIKE`` get dedicated handling and any other
                              accepted comparator is treated as equality.

        :return: List of :py:class:`mlflow.entities.model_registry.RegisteredModel` objects.
        :raises MlflowException: If the filter has more than one condition, uses an
                                 unsupported comparator, or filters on a key other than
                                 ``name``.
        """
        parsed_filter = SearchUtils.parse_filter_for_model_registry(filter_string)

        if len(parsed_filter) > 1:
            supported_ops = ''.join(
                f'({op})' for op in SearchUtils.VALID_REGISTERED_MODEL_SEARCH_COMPARATORS)
            sample_query = f'name {supported_ops} "<model_name>"'
            # Separators are spelled out explicitly: adjacent string literals are
            # concatenated without any whitespace between them.
            raise MlflowException(
                f'Invalid filter string: {filter_string}. '
                'Search registered models supports filter expressions like: ' +
                sample_query,
                error_code=INVALID_PARAMETER_VALUE)

        # An empty filter matches every registered model.
        conditions = []
        if len(parsed_filter) == 1:
            filter_dict = parsed_filter[0]
            comparator = filter_dict["comparator"]
            if comparator not in SearchUtils.VALID_REGISTERED_MODEL_SEARCH_COMPARATORS:
                raise MlflowException(
                    'Search registered models filter expression only '
                    'supports the equality(=) comparator, case-sensitive '
                    'partial match (LIKE), and case-insensitive partial '
                    'match (ILIKE). Input filter string: %s' % filter_string,
                    error_code=INVALID_PARAMETER_VALUE)
            if filter_dict["key"] != "name":
                raise MlflowException('Invalid filter string: %s' % filter_string,
                                      error_code=INVALID_PARAMETER_VALUE)

            if comparator == "LIKE":
                conditions = [SqlRegisteredModel.name.like(filter_dict["value"])]
            elif comparator == "ILIKE":
                conditions = [SqlRegisteredModel.name.ilike(filter_dict["value"])]
            else:
                conditions = [SqlRegisteredModel.name == filter_dict["value"]]

        with self.ManagedSessionMaker() as session:
            if self.db_type == SQLITE:
                # Make LIKE case-sensitive on SQLite so it is distinguishable from ILIKE.
                session.execute("PRAGMA case_sensitive_like = true;")
            sql_registered_models = session.query(SqlRegisteredModel).filter(
                *conditions).all()
            return [rm.to_mlflow_entity() for rm in sql_registered_models]
示例#9
0
 def _get_orderby_clauses(self, order_by_list: List[str]) -> List[dict]:
     """Translate ``order_by`` strings into a list of sort-clause dictionaries.

     Metrics, parameters and tags become nested sorts on the ``value`` field of their
     nested path, restricted to the requested key. Attributes are sorted on directly.
     ``start_time`` descending and ``run_id`` ascending are always appended last.
     NOTE(review): the dict layout looks like the Elasticsearch sort DSL; confirm
     against the caller.

     :param order_by_list: Order-by clauses, possibly empty or ``None``.
     :return: Sort clauses in the order they should be applied.
     """
     nested_path_by_type = {
         "metric": "latest_metrics",
         "parameter": "params",
         "tag": "tags"
     }
     sort_clauses = []
     for order_by_clause in order_by_list or []:
         key_type, key, ascending = SearchUtils.parse_order_by_for_search_runs(
             order_by_clause)
         sort_order = "asc" if ascending else "desc"

         if SearchUtils.is_attribute(key_type, "="):
             sort_clauses.append({key: {'order': sort_order}})
             continue

         nested_path = nested_path_by_type[key_type]
         # Only the nested entry whose key matches takes part in the sort.
         key_filter = {"term": {f'{nested_path}.key': key}}
         sort_clauses.append({
             f'{nested_path}.value': {
                 'order': sort_order,
                 "nested": {"path": nested_path, "filter": key_filter},
             }
         })

     sort_clauses.append({"start_time": {'order': "desc"}})
     sort_clauses.append({"run_id": {'order': "asc"}})
     return sort_clauses
示例#10
0
def test_pagination(page_token, max_results, matching_runs, expected_next_page_token):
    """``SearchUtils.paginate`` returns the expected slice of runs and next page token."""
    runs = [
        Run(
            run_info=RunInfo(
                run_uuid=run_id, run_id=run_id, experiment_id=0,
                user_id="user-id", status=RunStatus.to_string(RunStatus.FAILED),
                start_time=0, end_time=1, lifecycle_stage=LifecycleStage.ACTIVE),
            run_data=RunData([], [], []))
        for run_id in ("0", "1", "2")
    ]

    # Tokens travel as base64-encoded JSON; a falsy token means "first page".
    if page_token:
        encoded_page_token = base64.b64encode(json.dumps(page_token).encode("utf-8"))
    else:
        encoded_page_token = None
    paginated_runs, next_page_token = SearchUtils.paginate(runs, encoded_page_token, max_results)

    # Map every returned run back to the index of the first equal run in ``runs``.
    paginated_run_indices = []
    for returned_run in paginated_runs:
        position = next(
            (i for i, candidate in enumerate(runs) if candidate == returned_run), None)
        if position is not None:
            paginated_run_indices.append(position)
    assert paginated_run_indices == matching_runs

    if next_page_token:
        decoded_next_page_token = json.loads(base64.b64decode(next_page_token))
    else:
        decoded_next_page_token = None
    assert decoded_next_page_token == expected_next_page_token
示例#11
0
def test_correct_sorting(order_bys, matching_runs):
    """``SearchUtils.sort`` orders three fixture runs as listed in ``matching_runs``."""

    def build_run(run_id, experiment_id, status, start_time, metric_value, param_value, tags):
        info = RunInfo(
            run_uuid=run_id, run_id=run_id, experiment_id=experiment_id,
            user_id="user-id", status=RunStatus.to_string(status),
            start_time=start_time, end_time=1, lifecycle_stage=LifecycleStage.ACTIVE)
        data = RunData(
            metrics=[Metric("key1", metric_value, 1, 0)],
            params=[Param("my_param", param_value)],
            tags=tags)
        return Run(run_info=info, run_data=data)

    runs = [
        build_run("9", 0, RunStatus.FAILED, 0, 121, "A", []),
        build_run("8", 0, RunStatus.FINISHED, 1, 123, "A", [RunTag("tag1", "C")]),
        build_run("7", 1, RunStatus.FAILED, 1, 125, "B", [RunTag("tag1", "D")]),
    ]

    sorted_runs = SearchUtils.sort(runs, order_bys)

    # Map every sorted run back to the index of the first equal run in ``runs``.
    sorted_run_indices = []
    for sorted_run in sorted_runs:
        position = next(
            (i for i, candidate in enumerate(runs) if candidate == sorted_run), None)
        if position is not None:
            sorted_run_indices.append(position)
    assert sorted_run_indices == matching_runs
        def compute_next_token(current_size):
            """Return a token for the following page, or ``None`` when the page is not full.

            Reads ``max_results`` and ``offset`` from the enclosing scope.
            """
            if max_results == current_size:
                # A full page: the next page starts right after this one.
                return SearchUtils.create_page_token(offset + max_results)
            return None
示例#13
0
def test_correct_filtering(filter_string, matching_runs):
    """``SearchUtils.filter`` keeps exactly the fixture runs indexed by ``matching_runs``."""

    def build_run(run_id, experiment_id, status, metric_value, param_value, tags):
        info = RunInfo(
            run_uuid=run_id, run_id=run_id, experiment_id=experiment_id,
            user_id="user-id", status=RunStatus.to_string(status),
            start_time=0, end_time=1, lifecycle_stage=LifecycleStage.ACTIVE)
        data = RunData(
            metrics=[Metric("key1", metric_value, 1, 0)],
            params=[Param("my_param", param_value)],
            tags=tags)
        return Run(run_info=info, run_data=data)

    runs = [
        build_run("hi", 0, RunStatus.FAILED, 121, "A", []),
        build_run("hi2", 0, RunStatus.FINISHED, 123, "A", [RunTag("tag1", "C")]),
        build_run("hi3", 1, RunStatus.FAILED, 125, "B", [RunTag("tag1", "D")]),
    ]

    filtered_runs = SearchUtils.filter(runs, filter_string)
    expected_runs = {runs[i] for i in matching_runs}
    assert set(filtered_runs) == expected_runs
示例#14
0
    def _parse_search_registered_models_order_by(cls, order_by_list):
        """Translate ``order_by`` strings into ordering clauses for registered models.

        Each clause orders on either the model name or the last-updated timestamp. An
        ascending ordering on name is always appended last.

        :param order_by_list: Order-by clauses, possibly empty or ``None``.
        :return: List of ordering clauses, in the order they should be applied.
        :raises MlflowException: If a clause refers to a key that is neither the name
                                 column nor a valid timestamp key.
        """
        clauses = []
        for order_by_clause in order_by_list or []:
            attribute_token, ascending = \
                SearchUtils.parse_order_by_for_search_registered_models(order_by_clause)
            if attribute_token == SqlRegisteredModel.name.key:
                field = SqlRegisteredModel.name
            elif attribute_token in SearchUtils.VALID_TIMESTAMP_ORDER_BY_KEYS:
                field = SqlRegisteredModel.last_updated_time
            else:
                # The trailing space after "specified." is needed because adjacent
                # string literals are concatenated without any separator.
                raise MlflowException(
                    f"Invalid order by key '{attribute_token}' specified. "
                    f"Valid keys are "
                    f"'{SearchUtils.RECOMMENDED_ORDER_BY_KEYS_REGISTERED_MODELS}'",
                    error_code=INVALID_PARAMETER_VALUE)
            clauses.append(field.asc() if ascending else field.desc())

        clauses.append(SqlRegisteredModel.name.asc())
        return clauses
示例#15
0
 def _search_runs(self, experiment_ids, filter_string, run_view_type,
                  max_results, order_by, page_token):
     """Fetch every run of the given experiments, then filter, sort and page them in memory.

     :param experiment_ids: Experiments whose runs are searched.
     :param filter_string: Filter expression handed to ``SearchUtils.filter``.
     :param run_view_type: View type passed to ``_list_run_infos``.
     :param max_results: Page size; must not exceed ``SEARCH_MAX_RESULTS_THRESHOLD``.
     :param order_by: Ordering clauses handed to ``SearchUtils.sort``.
     :param page_token: Token handed to ``SearchUtils.paginate``.
     :return: Tuple of (runs on the requested page, next page token).
     :raises MlflowException: If ``max_results`` exceeds the threshold.
     """
     if max_results > SEARCH_MAX_RESULTS_THRESHOLD:
         message = (
             "Invalid value for request parameter max_results. It must be at "
             "most {}, but got value {}"
         ).format(SEARCH_MAX_RESULTS_THRESHOLD, max_results)
         raise MlflowException(message, databricks_pb2.INVALID_PARAMETER_VALUE)

     # Each run is fetched individually by id.
     runs = [
         self.get_run(run_info.run_id)
         for experiment_id in experiment_ids
         for run_info in self._list_run_infos(experiment_id, run_view_type)
     ]

     matching = SearchUtils.filter(runs, filter_string)
     ordered = SearchUtils.sort(matching, order_by)
     page, next_page_token = SearchUtils.paginate(ordered, page_token, max_results)
     return page, next_page_token
def _get_attributes_filtering_clauses(parsed):
    """Build filter expressions on ``SqlRun`` columns for the attribute clauses in ``parsed``.

    Non-attribute clauses are skipped, as are comparators found in neither operator table.

    :param parsed: Iterable of parsed clauses, each a mapping with ``type``, ``key``,
                   ``value`` and ``comparator`` entries.
    :return: List of filter expressions, one per handled attribute clause.
    """
    clauses = []
    for statement in parsed:
        key_type = statement.get("type")
        key_name = statement.get("key")
        value = statement.get("value")
        comparator = statement.get("comparator").upper()

        if not SearchUtils.is_attribute(key_type, comparator):
            continue

        # key_name is guaranteed to be a valid searchable attribute of entities.RunInfo
        # by the call to parse_search_filter
        column = getattr(SqlRun, SqlRun.get_attribute_name(key_name))
        if comparator in SearchUtils.CASE_INSENSITIVE_STRING_COMPARISON_OPERATORS:
            # These operators come already bound to the column and take only the value.
            column_op = SearchUtils.get_sql_filter_ops(column, comparator)
            clauses.append(column_op(value))
        elif comparator in SearchUtils.filter_ops:
            binary_op = SearchUtils.filter_ops.get(comparator)
            clauses.append(binary_op(column, value))
    return clauses
示例#17
0
 def _search_runs(self, experiment_ids, filter_string, run_view_type, max_results, order_by,
                  page_token):
     """Load all runs of the given experiments, then filter, sort and truncate them in memory.

     Pagination is not supported: any truthy ``page_token`` is rejected and the returned
     next-page token is always ``None``.

     :param experiment_ids: Experiments whose runs are searched.
     :param filter_string: Filter expression handed to ``SearchUtils.filter``.
     :param run_view_type: View type passed to ``_list_runs``.
     :param max_results: Maximum number of runs returned; must not exceed
                         ``SEARCH_MAX_RESULTS_THRESHOLD``.
     :param order_by: Ordering clauses handed to ``SearchUtils.sort``.
     :param page_token: Must be falsy.
     :return: Tuple of (at most ``max_results`` runs, ``None``).
     :raises MlflowException: If ``page_token`` is given or ``max_results`` exceeds the
                              threshold.
     """
     if page_token:
         # The space before "tokens." is explicit: adjacent string literals are
         # concatenated without any separator.
         raise MlflowException("SQLAlchemy-backed tracking stores do not yet support pagination "
                               "tokens.")
     # TODO: push search query into backend database layer
     if max_results > SEARCH_MAX_RESULTS_THRESHOLD:
         raise MlflowException("Invalid value for request parameter max_results. It must be at "
                               "most {}, but got value {}".format(SEARCH_MAX_RESULTS_THRESHOLD,
                                                                  max_results),
                               INVALID_PARAMETER_VALUE)
     with self.ManagedSessionMaker() as session:
         runs = [run.to_mlflow_entity()
                 for exp in experiment_ids
                 for run in self._list_runs(session, exp, run_view_type)]
         filtered = SearchUtils.filter(runs, filter_string)
         runs = SearchUtils.sort(filtered, order_by)[:max_results]
         return runs, None
示例#18
0
    def _list_experiments(
        self,
        ids=None,
        names=None,
        view_type=ViewType.ACTIVE_ONLY,
        max_results=None,
        page_token=None,
        eager=False,
    ):
        """
        List experiments, optionally restricted by id and name and optionally paginated.

        :param ids: If given and non-empty, only experiments with these ids (converted to
                    ``int``) are returned.
        :param names: If given and non-empty, only experiments with these names are returned.
        :param view_type: View type mapped to the lifecycle stages to include.
        :param max_results: If not ``None``, the page size. Results are then ordered by
                            experiment id and read starting at the offset decoded from
                            ``page_token``.
        :param page_token: Token from which the start offset is decoded. Only used when
                           ``max_results`` is not ``None``.
        :param eager: If ``True``, eagerly loads each experiments's tags. If ``False``, these tags
                      are not eagerly loaded and will be loaded if/when their corresponding
                      object properties are accessed from a resulting ``SqlExperiment`` object.
        :return: A ``PagedList`` of experiments. Its token is ``None`` when no pagination
                 was requested or when no further page was detected.
        """
        stages = LifecycleStage.view_type_to_stages(view_type)
        conditions = [SqlExperiment.lifecycle_stage.in_(stages)]
        if ids and len(ids) > 0:
            conditions.append(SqlExperiment.experiment_id.in_([int(eid) for eid in ids]))
        if names and len(names) > 0:
            conditions.append(SqlExperiment.name.in_(names))

        paginated = max_results is not None
        # One extra row is requested so that the existence of a next page can be detected.
        max_results_for_query = max_results + 1 if paginated else None

        with self.ManagedSessionMaker() as session:
            if eager:
                query_options = self._get_eager_experiment_query_options()
            else:
                query_options = []

            if paginated:
                offset = SearchUtils.parse_start_offset_from_page_token(page_token)
                query = session.query(SqlExperiment).options(*query_options)
                query = query.order_by(SqlExperiment.experiment_id).filter(*conditions)
                query = query.offset(offset).limit(max_results_for_query)
            else:
                query = session.query(SqlExperiment).options(*query_options)
                query = query.filter(*conditions)

            experiments = [sql_exp.to_mlflow_entity() for sql_exp in query.all()]

        if not paginated:
            return PagedList(experiments, None)

        page = experiments[:max_results]
        next_token = None
        # The look-ahead row came back, so at least one more experiment follows this page.
        if max_results_for_query == len(experiments):
            next_token = SearchUtils.create_page_token(offset + max_results)
        return PagedList(page, next_token)
示例#19
0
def test_filter_runs_by_start_time():
    """Filtering on ``attribute.start_time`` supports the ``>=``, ``>`` and ``=`` comparators."""
    runs = []
    # Runs "a", "b" and "c" start at times 0, 1 and 2 respectively.
    for start_time, run_id in enumerate(["a", "b", "c"]):
        info = RunInfo(
            run_uuid=run_id,
            run_id=run_id,
            experiment_id=0,
            user_id="user-id",
            status=RunStatus.to_string(RunStatus.FINISHED),
            start_time=start_time,
            end_time=1,
            lifecycle_stage=LifecycleStage.ACTIVE,
        )
        runs.append(Run(run_info=info, run_data=RunData()))

    last_run_only = runs[2:]
    assert SearchUtils.filter(runs, "attribute.start_time >= 0") == runs
    assert SearchUtils.filter(runs, "attribute.start_time > 1") == last_run_only
    assert SearchUtils.filter(runs, "attribute.start_time = 2") == last_run_only
示例#20
0
 def search_runs(self,
                 experiment_ids,
                 filter_string,
                 run_view_type,
                 max_results=SEARCH_MAX_RESULTS_THRESHOLD,
                 order_by=None):
     """Load all runs of the given experiments, then filter, sort and truncate them in memory.

     :param experiment_ids: Experiments whose runs are searched.
     :param filter_string: Filter expression handed to ``SearchUtils.filter``.
     :param run_view_type: View type passed to ``_list_runs``.
     :param max_results: Maximum number of runs returned; must not exceed
                         ``SEARCH_MAX_RESULTS_THRESHOLD``.
     :param order_by: Ordering clauses handed to ``SearchUtils.sort``.
     :return: At most ``max_results`` runs, taken from the start of the sorted result.
     :raises MlflowException: If ``max_results`` exceeds the threshold.
     """
     # TODO: push search query into backend database layer
     if max_results > SEARCH_MAX_RESULTS_THRESHOLD:
         message = (
             "Invalid value for request parameter max_results. It must be at "
             "most {}, but got value {}"
         ).format(SEARCH_MAX_RESULTS_THRESHOLD, max_results)
         raise MlflowException(message, INVALID_PARAMETER_VALUE)

     with self.ManagedSessionMaker() as session:
         runs = []
         for exp in experiment_ids:
             for sql_run in self._list_runs(session, exp, run_view_type):
                 runs.append(sql_run.to_mlflow_entity())
         matching = SearchUtils.filter(runs, filter_string)
         ordered = SearchUtils.sort(matching, order_by)
         return ordered[:max_results]
示例#21
0
def _get_orderby_clauses(order_by_list, session):
    """Translate ``order_by`` strings into ordering clauses plus the subqueries they need.

    Every requested ordering contributes two clauses: a labelled CASE expression that
    pushes missing values (and NaN metrics) to the end, followed by the value itself.
    ``SqlRun.start_time`` descending and ``SqlRun.run_uuid`` are always appended last.

    :param order_by_list: Order-by clauses, possibly empty or ``None``.
    :param session: Session used to build the per-key subqueries.
    :return: Tuple of (ordering clauses, subqueries the main query must join so that the
             sorted-on values are available).
    :raises MlflowException: If a non-attribute clause is not a metric, tag or param.
    """
    clauses = []
    ordering_joins = []
    # contrary to filters, it is not easily feasible to separately handle sorting
    # on attributes and on joined tables as we must keep all clauses in the same order
    for clause_id, order_by_clause in enumerate(order_by_list or [], start=1):
        key_type, key, ascending = SearchUtils.parse_order_by_for_search_runs(order_by_clause)

        if SearchUtils.is_attribute(key_type, "="):
            order_value = getattr(SqlRun, SqlRun.get_attribute_name(key))
        else:
            if SearchUtils.is_metric(key_type, "="):  # any valid comparator
                entity = SqlLatestMetric
            elif SearchUtils.is_tag(key_type, "="):
                entity = SqlTag
            elif SearchUtils.is_param(key_type, "="):
                entity = SqlParam
            else:
                raise MlflowException(
                    "Invalid identifier type '%s'" % key_type,
                    error_code=INVALID_PARAMETER_VALUE,
                )

            # build a subquery first because we will join it in the main request so that the
            # metric we want to sort on is available when we apply the sorting clause
            subquery = session.query(entity).filter(entity.key == key).subquery()
            ordering_joins.append(subquery)
            order_value = subquery.c.value

        # sqlite does not support NULLS LAST expression, so we sort first by
        # presence of the field (and is_nan for metrics), then by actual value
        if SearchUtils.is_metric(key_type, "="):
            whens = [(subquery.c.is_nan.is_(True), 1), (order_value.is_(None), 1)]
        else:  # other entities do not have an 'is_nan' field
            whens = [(order_value.is_(None), 1)]

        # As the subqueries are created independently and used later in the
        # same main query, the CASE WHEN columns need to have unique names to
        # avoid ambiguity
        clauses.append(sql.case(whens, else_=0).label("clause_%s" % clause_id))
        clauses.append(order_value if ascending else order_value.desc())

    clauses.append(SqlRun.start_time.desc())
    clauses.append(SqlRun.run_uuid)
    return clauses, ordering_joins
示例#22
0
    def search_model_versions(self, filter_string):
        """
        Search for model versions in backend that satisfy the filter criteria.

        Versions whose current stage is ``STAGE_DELETED_INTERNAL`` are always excluded.

        :param filter_string: A filter string expression. Currently supports at most one
                              equality condition, on one of ``name``, ``source_path`` or
                              ``run_id``, e.g. ``name = 'model_name'`` or ``run_id = '...'``.
        :return: PagedList of :py:class:`mlflow.entities.model_registry.ModelVersion`
                 objects. The page token is always ``None``.
        :raises MlflowException: If the filter has more than one condition, uses a
                                 comparator other than ``=``, or filters on an unsupported key.
        """
        parsed_filter = SearchUtils.parse_filter_for_model_versions(filter_string)

        if len(parsed_filter) > 1:
            # The closing quote after '<run_id>' and the separator before the echoed
            # input are explicit: adjacent string literals are joined as-is.
            raise MlflowException(
                "Model Registry expects filter to be one of "
                "\"name = '<model_name>'\" or "
                "\"source_path = '<source_path>'\" or \"run_id = '<run_id>'\". "
                "Input filter string: %s. " % filter_string,
                error_code=INVALID_PARAMETER_VALUE,
            )

        # An empty filter matches every non-deleted model version.
        conditions = []
        if len(parsed_filter) == 1:
            filter_dict = parsed_filter[0]
            if filter_dict["comparator"] != "=":
                raise MlflowException(
                    "Model Registry search filter only supports equality(=) "
                    "comparator. Input filter string: %s" % filter_string,
                    error_code=INVALID_PARAMETER_VALUE,
                )
            if filter_dict["key"] == "name":
                conditions = [SqlModelVersion.name == filter_dict["value"]]
            elif filter_dict["key"] == "source_path":
                conditions = [SqlModelVersion.source == filter_dict["value"]]
            elif filter_dict["key"] == "run_id":
                conditions = [SqlModelVersion.run_id == filter_dict["value"]]
            else:
                raise MlflowException("Invalid filter string: %s" % filter_string,
                                      error_code=INVALID_PARAMETER_VALUE)

        with self.ManagedSessionMaker() as session:
            conditions.append(SqlModelVersion.current_stage != STAGE_DELETED_INTERNAL)
            sql_model_versions = session.query(SqlModelVersion).filter(*conditions).all()
            model_versions = [mv.to_mlflow_entity() for mv in sql_model_versions]
            return PagedList(model_versions, None)
示例#23
0
    def list_experiments(
        self,
        view_type=ViewType.ACTIVE_ONLY,
        max_results=None,
        page_token=None,
    ):
        """
        List the experiments visible under ``view_type``, optionally one page at a time.

        Experiments that raise ``MissingConfigException`` while loading are logged as
        malformed and skipped.

        :param view_type: Qualify requested type of experiments.
        :param max_results: If passed, specifies the maximum number of experiments desired. If not
                            passed, all experiments will be returned.
        :param page_token: Token specifying the next page of results. It should be obtained from
                           a ``list_experiments`` call.
        :return: A :py:class:`PagedList <mlflow.store.entities.PagedList>` of
                 :py:class:`Experiment <mlflow.entities.Experiment>` objects. The pagination token
                 for the next page can be obtained via the ``token`` attribute of the object.
        """
        from mlflow.utils.search_utils import SearchUtils
        from mlflow.store.entities.paged_list import PagedList

        _validate_list_experiments_max_results(max_results)
        self._check_root_dir()

        experiment_ids = []
        if view_type in (ViewType.ACTIVE_ONLY, ViewType.ALL):
            experiment_ids.extend(self._get_active_experiments(full_path=False))
        if view_type in (ViewType.DELETED_ONLY, ViewType.ALL):
            experiment_ids.extend(self._get_deleted_experiments(full_path=False))

        experiments = []
        for exp_id in experiment_ids:
            try:
                # trap and warn known issues, will raise unexpected exceptions to caller
                loaded = self._get_experiment(exp_id, view_type)
            except MissingConfigException as rnfe:
                # Trap malformed experiments and log warnings.
                logging.warning(
                    "Malformed experiment '%s'. Detailed error %s",
                    str(exp_id),
                    str(rnfe),
                    exc_info=True,
                )
            else:
                if loaded:
                    experiments.append(loaded)

        if max_results is None:
            return PagedList(experiments, None)

        page, next_page_token = SearchUtils.paginate(experiments, page_token, max_results)
        return PagedList(page, next_page_token)
示例#24
0
    def _search_runs(
        self,
        experiment_ids,
        filter_string,
        run_view_type,
        max_results,
        order_by,
        page_token,
    ):
        """Assemble every run of the given experiments, then filter, sort and page in memory.

        :param experiment_ids: Experiments whose runs are searched.
        :param filter_string: Filter expression handed to ``SearchUtils.filter``.
        :param run_view_type: View type passed to ``_list_runs_ids``.
        :param max_results: Page size; must not exceed ``SEARCH_MAX_RESULTS_THRESHOLD``.
        :param order_by: Ordering clauses handed to ``SearchUtils.sort``.
        :param page_token: Token handed to ``SearchUtils.paginate``.
        :return: Tuple of (runs on the requested page, next page token).
        :raises MlflowException: If ``max_results`` exceeds the threshold.
        """
        if max_results > SEARCH_MAX_RESULTS_THRESHOLD:
            message = (
                "Invalid value for request parameter max_results. It must be at "
                "most {}, but got value {}"
            ).format(SEARCH_MAX_RESULTS_THRESHOLD, max_results)
            raise MlflowException(message, INVALID_PARAMETER_VALUE)

        runs = []
        for experiment_id in experiment_ids:
            run_ids = self._list_runs_ids(experiment_id, run_view_type)
            raw_infos = self._get_run_list(run_ids)
            for run_info in [_dict_to_run_info(raw) for raw in raw_infos]:
                # Load the metrics, params and tags for the run
                run_id = run_info.run_id
                run_data = RunData(
                    self.get_all_metrics(run_id),
                    self.get_all_params(run_id),
                    self.get_all_tags(run_id),
                )
                runs.append(Run(run_info, run_data))

        matching = SearchUtils.filter(runs, filter_string)
        ordered = SearchUtils.sort(matching, order_by)
        page, next_page_token = SearchUtils.paginate(ordered, page_token, max_results)
        return page, next_page_token
示例#25
0
def _get_attributes_filtering_clauses(parsed):
    clauses = []
    for sql_statement in parsed:
        key_type = sql_statement.get('type')
        key_name = sql_statement.get('key')
        value = sql_statement.get('value')
        comparator = sql_statement.get('comparator')
        if SearchUtils.is_attribute(key_type, comparator):
            # validity of the comparator is checked in SearchUtils.parse_search_filter()
            op = SearchUtils.filter_ops.get(comparator)
            if op:
                # key_name is guaranteed to be a valid searchable attribute of entities.RunInfo
                # by the call to parse_search_filter
                attribute_name = SqlRun.get_attribute_name(key_name)
                clauses.append(op(getattr(SqlRun, attribute_name), value))
    return clauses
示例#26
0
    def _parse_search_registered_models_order_by(cls, order_by_list):
        """Translate ``order_by`` strings into ordering clauses for registered models.

        Each clause is parsed with
        ``SearchUtils.parse_order_by_for_search_registered_models`` and mapped
        onto a ``SqlRegisteredModel`` column. An ascending sort on the model
        name is appended as the last clause whenever the caller did not order
        by name explicitly, so the natural ordering (name ascending) always
        applies as a tie-breaker.

        :param order_by_list: List of order-by clause strings; may be ``None``
                              or empty.
        :return: List of column ordering clauses (results of ``.asc()`` /
                 ``.desc()`` on ``SqlRegisteredModel`` columns).
        :raises MlflowException: With ``INVALID_PARAMETER_VALUE`` if a clause
                                 names an unsupported key or if two clauses
                                 resolve to the same column.
        """
        clauses = []
        observed_order_by_clauses = set()
        for order_by_clause in order_by_list or []:
            (
                attribute_token,
                ascending,
            ) = SearchUtils.parse_order_by_for_search_registered_models(
                order_by_clause)
            if attribute_token == SqlRegisteredModel.name.key:
                field = SqlRegisteredModel.name
            elif attribute_token in SearchUtils.VALID_TIMESTAMP_ORDER_BY_KEYS:
                # Every accepted timestamp key sorts on the same column.
                field = SqlRegisteredModel.last_updated_time
            else:
                raise MlflowException(
                    "Invalid order by key '{}' specified. "
                    "Valid keys are '{}'".format(
                        attribute_token,
                        SearchUtils.RECOMMENDED_ORDER_BY_KEYS_REGISTERED_MODELS),
                    error_code=INVALID_PARAMETER_VALUE,
                )
            # Duplicates are detected on the resolved column, so two different
            # timestamp keys in the same request count as a duplicate.
            if field.key in observed_order_by_clauses:
                raise MlflowException(
                    "`order_by` contains duplicate fields: {}".format(
                        order_by_list),
                    error_code=INVALID_PARAMETER_VALUE,
                )
            observed_order_by_clauses.add(field.key)
            clauses.append(field.asc() if ascending else field.desc())

        if SqlRegisteredModel.name.key not in observed_order_by_clauses:
            clauses.append(SqlRegisteredModel.name.asc())
        return clauses
示例#27
0
    def _search_runs(
            self,
            experiment_ids: List[str],
            filter_string: str,
            run_view_type: str,
            max_results: int = SEARCH_MAX_RESULTS_DEFAULT,
            order_by: List[str] = None,
            page_token: str = None,
            columns_to_whitelist: List[str] = None) -> Tuple[List[Run], str]:
        """Search runs in the ``mlflow-runs`` Elasticsearch index.

        The query filters on the experiment id, on the lifecycle stages implied
        by ``run_view_type`` and on the parsed ``filter_string``; results are
        sorted by the clauses derived from ``order_by`` and paged with
        ``search_after``.

        :param experiment_ids: Experiment ids to search.
        :param filter_string: Filter expression parsed with
                              ``SearchUtils.parse_search_filter``.
        :param run_view_type: View type converted to lifecycle stages.
        :param max_results: Page size; at most 10000.
        :param order_by: Order-by clauses handed to ``self._get_orderby_clauses``.
        :param page_token: String form of the sort values of the last hit of the
                           previous page; ``None`` or ``""`` requests the first
                           page.
        :param columns_to_whitelist: Columns handed to
                                     ``self._build_columns_to_whitelist_key_dict``.
        :return: Tuple of (runs, next page token). The token is ``str([])`` when
                 the page was not full.
        :raises MlflowException: With ``INVALID_PARAMETER_VALUE`` if
                                 ``max_results`` is too large or ``page_token``
                                 cannot be parsed.
        """
        if max_results > 10000:
            raise MlflowException(
                "Invalid value for request parameter max_results. It must be at "
                "most {}, but got value {}".format(10000, max_results),
                INVALID_PARAMETER_VALUE)
        stages = LifecycleStage.view_type_to_stages(run_view_type)
        parsed_filters = SearchUtils.parse_search_filter(filter_string)
        # NOTE(review): only the first experiment id is used in the query;
        # confirm that searching a single experiment is intended.
        filter_queries = [
            Q("match", experiment_id=experiment_ids[0]),
            Q("terms", lifecycle_stage=stages)
        ]
        filter_queries += self._build_elasticsearch_query(parsed_filters)
        sort_clauses = self._get_orderby_clauses(order_by)
        s = Search(index="mlflow-runs").query('bool', filter=filter_queries)
        s = s.sort(*sort_clauses)
        if page_token is not None and page_token != "":
            # The token is client-supplied: report a malformed value as an
            # invalid parameter instead of leaking a raw parsing error.
            try:
                search_after = ast.literal_eval(page_token)
            except (ValueError, TypeError, SyntaxError) as exc:
                raise MlflowException(
                    "Invalid page token, could not parse '{}'".format(
                        page_token),
                    INVALID_PARAMETER_VALUE) from exc
            s = s.extra(search_after=search_after)
        response = s.params(size=max_results).execute()
        columns_to_whitelist_key_dict = self._build_columns_to_whitelist_key_dict(
            columns_to_whitelist)
        runs = [
            self._hit_to_mlflow_run(hit, columns_to_whitelist_key_dict)
            for hit in response
        ]
        # A full page means more results may follow. The ``runs`` guard keeps
        # ``max_results == 0`` from indexing into an empty hit list.
        if runs and len(runs) == max_results:
            next_page_token = response.hits.hits[-1].sort
        else:
            next_page_token = []
        return runs, str(next_page_token)
示例#28
0
def test_invalid_page_tokens(page_token, error_message):
    """``SearchUtils.paginate`` must reject ``page_token`` with ``error_message``."""
    with pytest.raises(MlflowException) as exc_info:
        SearchUtils.paginate([], page_token, 1)
    raised = exc_info.value
    assert error_message in raised.message
示例#29
0
def test_invalid_order_by(order_by, error_message):
    """``SearchUtils._parse_order_by`` must reject ``order_by`` with ``error_message``."""
    with pytest.raises(MlflowException) as exc_info:
        SearchUtils._parse_order_by(order_by)
    raised = exc_info.value
    assert error_message in raised.message
示例#30
0
def test_invalid_clauses(filter_string, error_message):
    """``SearchUtils.parse_search_filter`` must reject ``filter_string`` with ``error_message``."""
    with pytest.raises(MlflowException) as exc_info:
        SearchUtils.parse_search_filter(filter_string)
    raised = exc_info.value
    assert error_message in raised.message