Esempio n. 1
0
def _to_sqlalchemy_filtering_statement(sql_statement, session):
    """Translate one parsed filter condition into a SQLAlchemy subquery.

    :param sql_statement: Mapping describing a single condition; the keys
                          ``type``, ``key``, ``value`` and ``comparator`` are read.
    :param session: SQLAlchemy session used to build the subquery.
    :return: A subquery over the metric, param or tag table restricted to the
             requested key and value condition, or ``None`` when the condition
             targets a run attribute or the comparator is not a known operator.
    :raises MlflowException: If the condition type is none of metric, param,
                             tag or attribute.
    """
    key_type = sql_statement.get('type')
    key_name = sql_statement.get('key')
    value = sql_statement.get('value')
    comparator = sql_statement.get('comparator').upper()

    # Pick the table that stores the kind of key being filtered on.
    if SearchUtils.is_metric(key_type, comparator):
        # Metric values are compared numerically.
        entity, value = SqlLatestMetric, float(value)
    elif SearchUtils.is_param(key_type, comparator):
        entity = SqlParam
    elif SearchUtils.is_tag(key_type, comparator):
        entity = SqlTag
    elif SearchUtils.is_attribute(key_type, comparator):
        # Attribute conditions are not expressed as a subquery here.
        return None
    else:
        raise MlflowException("Invalid search expression type '%s'" % key_type,
                              error_code=INVALID_PARAMETER_VALUE)

    # The two operator families are called differently: the case-insensitive
    # string operators are already bound to the column, the generic ones take
    # the column as their first argument.
    if comparator in SearchUtils.CASE_INSENSITIVE_STRING_COMPARISON_OPERATORS:
        op = SearchUtils.get_sql_filter_ops(entity.value, comparator)
        op_args = (value,)
    elif comparator in SearchUtils.filter_ops:
        op = SearchUtils.filter_ops.get(comparator)
        op_args = (entity.value, value)
    else:
        return None

    entity_query = session.query(entity)
    return entity_query.filter(entity.key == key_name, op(*op_args)).subquery()
Esempio n. 2
0
    def _search_runs(self, experiment_ids, filter_string, run_view_type,
                     max_results, order_by, page_token):
        """Return one page of runs matching the filter, plus the next page token.

        :param experiment_ids: Experiment ids the runs must belong to.
        :param filter_string: Search filter expression, parsed by ``SearchUtils``.
        :param run_view_type: View type, mapped to the lifecycle stages to include.
        :param max_results: Page size; must not exceed ``SEARCH_MAX_RESULTS_THRESHOLD``.
        :param order_by: Order-by clauses, parsed by ``_get_orderby_clauses``.
        :param page_token: Token encoding the offset of the page to fetch.
        :return: Tuple ``(runs, next_page_token)``; the token is ``None`` unless the
                 returned page holds exactly ``max_results`` runs.
        :raises MlflowException: If ``max_results`` exceeds the allowed threshold.
        """
        if max_results > SEARCH_MAX_RESULTS_THRESHOLD:
            raise MlflowException(
                "Invalid value for request parameter max_results. It must be at "
                "most {}, but got value {}".format(
                    SEARCH_MAX_RESULTS_THRESHOLD, max_results),
                INVALID_PARAMETER_VALUE)

        stages = set(LifecycleStage.view_type_to_stages(run_view_type))

        with self.ManagedSessionMaker() as session:
            parsed_filters = SearchUtils.parse_search_filter(filter_string)
            parsed_orderby, sorting_joins = _get_orderby_clauses(
                order_by, session)

            query = session.query(SqlRun)
            # Filter subqueries are inner-joined so that runs without a matching
            # row are dropped.
            for filter_subquery in _get_sqlalchemy_filter_clauses(parsed_filters, session):
                query = query.join(filter_subquery)
            # Sorting subqueries are outer-joined instead: a run lacking the
            # tag, metric or param being sorted on must stay in the result.
            for sorting_subquery in sorting_joins:
                query = query.outerjoin(sorting_subquery)

            offset = SearchUtils.parse_start_offset_from_page_token(page_token)
            # Summary metrics, params and tags are loaded eagerly because
            # ``run.to_mlflow_entity()`` reads them; lazy loading would issue
            # extra database queries at attribute access time.
            queried_runs = (
                query.distinct()
                .options(*self._get_eager_run_query_options())
                .filter(
                    SqlRun.experiment_id.in_(experiment_ids),
                    SqlRun.lifecycle_stage.in_(stages),
                    *_get_attributes_filtering_clauses(parsed_filters))
                .order_by(*parsed_orderby)
                .offset(offset)
                .limit(max_results)
                .all()
            )

            runs = [sql_run.to_mlflow_entity() for sql_run in queried_runs]

            # A full page means there may be more results: hand out a token
            # pointing just past the rows returned so far.
            next_page_token = None
            if max_results == len(runs):
                next_page_token = SearchUtils.create_page_token(offset + max_results)

        return runs, next_page_token
Esempio n. 3
0
    def _parse_search_registered_models_order_by(cls, order_by_list):
        """Build the SQLAlchemy ORDER BY clauses for a registered model search.

        Each requested clause is translated to an ascending or descending
        ordering on the model name or on its last-updated time. An ascending
        ordering by name is always appended last as the natural ordering.

        NOTE(review): the first parameter is ``cls``; presumably a
        ``@classmethod`` decorator sits just above this definition - confirm.

        :param order_by_list: Iterable of order-by strings, or ``None``/empty to
                              use only the natural ordering.
        :return: List of SQLAlchemy ordering expressions.
        :raises MlflowException: If a clause refers to a key that is neither the
                                 name nor a valid timestamp key.
        """
        clauses = []
        for order_by_clause in order_by_list or []:
            attribute_token, ascending = \
                SearchUtils.parse_order_by_for_search_registered_models(order_by_clause)
            if attribute_token == SqlRegisteredModel.name.key:
                field = SqlRegisteredModel.name
            elif attribute_token in SearchUtils.VALID_TIMESTAMP_ORDER_BY_KEYS:
                field = SqlRegisteredModel.last_updated_time
            else:
                # The trailing space after "specified." keeps the two
                # sentences from running together in the final message.
                raise MlflowException(
                    f"Invalid order by key '{attribute_token}' specified. "
                    f"Valid keys are "
                    f"'{SearchUtils.RECOMMENDED_ORDER_BY_KEYS_REGISTERED_MODELS}'",
                    error_code=INVALID_PARAMETER_VALUE)
            clauses.append(field.asc() if ascending else field.desc())

        # Natural ordering, also serving as the final tie-breaker.
        clauses.append(SqlRegisteredModel.name.asc())
        return clauses
Esempio n. 4
0
        def compute_next_token(current_size):
            """Return the token for the following page, or ``None`` if the page is not full.

            ``max_results`` and ``offset`` are read from the enclosing scope.
            """
            if max_results != current_size:
                return None
            # A full page: the next one starts right after the rows just returned.
            return SearchUtils.create_page_token(offset + max_results)
Esempio n. 5
0
 def _search_runs(self, experiment_ids, filter_string, run_view_type,
                  max_results, order_by, page_token):
     """Collect the runs of the given experiments, then filter, sort and paginate them.

     :param experiment_ids: Experiments whose runs are searched.
     :param filter_string: Filter expression handed to ``SearchUtils.filter``.
     :param run_view_type: View type handed to ``_list_run_infos``.
     :param max_results: Page size; must not exceed ``SEARCH_MAX_RESULTS_THRESHOLD``.
     :param order_by: Ordering handed to ``SearchUtils.sort``.
     :param page_token: Pagination token handed to ``SearchUtils.paginate``.
     :return: Tuple ``(runs, next_page_token)`` as produced by ``SearchUtils.paginate``.
     :raises MlflowException: If ``max_results`` exceeds the allowed threshold.
     """
     if max_results > SEARCH_MAX_RESULTS_THRESHOLD:
         raise MlflowException(
             "Invalid value for request parameter max_results. It must be at "
             "most {}, but got value {}".format(
                 SEARCH_MAX_RESULTS_THRESHOLD, max_results),
             databricks_pb2.INVALID_PARAMETER_VALUE)

     # Every run of every requested experiment is materialised before filtering.
     all_runs = [
         self._get_run_from_info(run_info)
         for experiment_id in experiment_ids
         for run_info in self._list_run_infos(experiment_id, run_view_type)
     ]
     matching_runs = SearchUtils.filter(all_runs, filter_string)
     ordered_runs = SearchUtils.sort(matching_runs, order_by)
     page, next_page_token = SearchUtils.paginate(
         ordered_runs, page_token, max_results)
     return page, next_page_token
Esempio n. 6
0
def _get_attributes_filtering_clauses(parsed):
    """Build the filter clauses for the conditions that target run attributes.

    Conditions of any other type are ignored, and so are attribute conditions
    whose comparator is not a known operator.

    :param parsed: Iterable of parsed filter conditions; the keys ``type``,
                   ``key``, ``value`` and ``comparator`` of each one are read.
    :return: List of SQLAlchemy expressions on ``SqlRun`` columns.
    """
    attribute_clauses = []
    for statement in parsed:
        key_type = statement.get('type')
        key_name = statement.get('key')
        value = statement.get('value')
        comparator = statement.get('comparator').upper()
        if not SearchUtils.is_attribute(key_type, comparator):
            continue

        # parse_search_filter guarantees that key_name is a valid searchable
        # attribute of entities.RunInfo, so the lookup below is safe.
        column = getattr(SqlRun, SqlRun.get_attribute_name(key_name))
        if comparator in SearchUtils.CASE_INSENSITIVE_STRING_COMPARISON_OPERATORS:
            bound_op = SearchUtils.get_sql_filter_ops(column, comparator)
            attribute_clauses.append(bound_op(value))
        elif comparator in SearchUtils.filter_ops:
            binary_op = SearchUtils.filter_ops.get(comparator)
            attribute_clauses.append(binary_op(column, value))
    return attribute_clauses
Esempio n. 7
0
    def search_model_versions(self, filter_string):
        """
        Search for model versions in backend that satisfy the filter criteria.

        :param filter_string: A filter string expression. Currently supports a single
                              equality condition on one of ``name``, ``source_path``
                              or ``run_id``, e.g. ``name = 'model_name'`` or
                              ``run_id = '...'``. An expression that parses to no
                              condition matches every non-deleted model version.
        :return: PagedList of :py:class:`mlflow.entities.model_registry.ModelVersion`
                 objects. The page token is always ``None``.
        :raises MlflowException: If the filter has more than one condition, uses a
                                 comparator other than ``=``, or filters on an
                                 unsupported key.
        """
        parsed_filter = SearchUtils.parse_filter_for_model_versions(
            filter_string)
        if len(parsed_filter) == 0:
            conditions = []
        elif len(parsed_filter) == 1:
            filter_dict = parsed_filter[0]
            if filter_dict["comparator"] != "=":
                raise MlflowException(
                    'Model Registry search filter only supports equality(=) '
                    'comparator. Input filter string: %s' % filter_string,
                    error_code=INVALID_PARAMETER_VALUE)
            if filter_dict["key"] == "name":
                conditions = [SqlModelVersion.name == filter_dict["value"]]
            elif filter_dict["key"] == "source_path":
                conditions = [SqlModelVersion.source == filter_dict["value"]]
            elif filter_dict["key"] == "run_id":
                conditions = [SqlModelVersion.run_id == filter_dict["value"]]
            else:
                raise MlflowException('Invalid filter string: %s' %
                                      filter_string,
                                      error_code=INVALID_PARAMETER_VALUE)
        else:
            # The closing double quote after '<run_id>' and the space before
            # "Input filter string" keep the message balanced and readable.
            raise MlflowException(
                'Model Registry expects filter to be one of '
                '"name = \'<model_name>\'" or '
                '"source_path = \'<source_path>\'" or "run_id = \'<run_id>\'". '
                'Input filter string: %s. ' % filter_string,
                error_code=INVALID_PARAMETER_VALUE)

        with self.ManagedSessionMaker() as session:
            # Versions flagged with the internal "deleted" stage are never returned.
            conditions.append(
                SqlModelVersion.current_stage != STAGE_DELETED_INTERNAL)
            sql_model_version = session.query(SqlModelVersion).filter(
                *conditions).all()
            model_versions = [
                mv.to_mlflow_entity() for mv in sql_model_version
            ]
            return PagedList(model_versions, None)
Esempio n. 8
0
    def search_registered_models(
            self,
            filter_string=None,
            max_results=SEARCH_REGISTERED_MODEL_MAX_RESULTS_DEFAULT,
            order_by=None,
            page_token=None):
        """
        Search for registered models in backend that satisfy the filter criteria.

        :param filter_string: Filter query string, defaults to searching all registered models.
                              At most one condition on the model name is supported.
        :param max_results: Maximum number of registered models desired.
        :param order_by: List of column names with ASC|DESC annotation, to be used for ordering
                         matching search results.
        :param page_token: Token specifying the next page of results. It should be obtained from
                            a ``search_registered_models`` call.
        :return: A PagedList of :py:class:`mlflow.entities.model_registry.RegisteredModel` objects
                that satisfy the search expressions. The pagination token for the next page can be
                obtained via the ``token`` attribute of the object.
        :raises MlflowException: If ``max_results`` exceeds the allowed threshold, if the
                                 filter uses an unsupported comparator, or if the filter holds
                                 more than one condition.
        """
        if max_results > SEARCH_REGISTERED_MODEL_MAX_RESULTS_THRESHOLD:
            raise MlflowException(
                "Invalid value for request parameter max_results. "
                "It must be at most {}, but got value {}".format(
                    SEARCH_REGISTERED_MODEL_MAX_RESULTS_THRESHOLD,
                    max_results), INVALID_PARAMETER_VALUE)

        parsed_filter = SearchUtils.parse_filter_for_registered_models(
            filter_string)
        parsed_orderby = self._parse_search_registered_models_order_by(
            order_by)
        offset = SearchUtils.parse_start_offset_from_page_token(page_token)
        # we query for max_results + 1 items to check whether there is another page to return.
        # this remediates having to make another query which returns no items.
        max_results_for_query = max_results + 1

        def compute_next_token(current_size):
            """Return a token for the next page only if the extra probe row came back."""
            next_token = None
            if max_results_for_query == current_size:
                final_offset = offset + max_results
                next_token = SearchUtils.create_page_token(final_offset)
            return next_token

        if len(parsed_filter) == 0:
            conditions = []
        elif len(parsed_filter) == 1:
            filter_dict = parsed_filter[0]
            comparator = filter_dict['comparator'].upper()
            if comparator not in \
                    SearchUtils.VALID_REGISTERED_MODEL_SEARCH_COMPARATORS:
                # Each fragment ends with a space so the implicitly
                # concatenated pieces do not run together.
                raise MlflowException(
                    'Search registered models filter expression only '
                    'supports the equality(=) comparator, case-sensitive '
                    'partial match (LIKE), and case-insensitive partial '
                    'match (ILIKE). Input filter string: %s' % filter_string,
                    error_code=INVALID_PARAMETER_VALUE)
            if comparator == SearchUtils.LIKE_OPERATOR:
                conditions = [
                    SqlRegisteredModel.name.like(filter_dict["value"])
                ]
            elif comparator == SearchUtils.ILIKE_OPERATOR:
                conditions = [
                    SqlRegisteredModel.name.ilike(filter_dict["value"])
                ]
            else:
                conditions = [SqlRegisteredModel.name == filter_dict["value"]]
        else:
            supported_ops = ''.join(
                f'({op})'
                for op in SearchUtils.VALID_REGISTERED_MODEL_SEARCH_COMPARATORS
            )
            sample_query = f'name {supported_ops} "<model_name>"'
            # Explicit separators keep the filter string, the explanation and
            # the sample query apart in the final message.
            raise MlflowException(
                f'Invalid filter string: {filter_string}. '
                f'Search registered models supports filter expressions like: '
                f'{sample_query}',
                error_code=INVALID_PARAMETER_VALUE)
        with self.ManagedSessionMaker() as session:
            query = session\
                .query(SqlRegisteredModel)\
                .filter(*conditions)\
                .order_by(*parsed_orderby)\
                .limit(max_results_for_query)
            if page_token:
                query = query.offset(offset)
            sql_registered_models = query.all()
            next_page_token = compute_next_token(len(sql_registered_models))
            # Drop the extra probe row before returning the page.
            rm_entities = [
                rm.to_mlflow_entity() for rm in sql_registered_models
            ][:max_results]
            return PagedList(rm_entities, next_page_token)
Esempio n. 9
0
def _get_orderby_clauses(order_by_list, session):
    """Build the ORDER BY clauses for a run search and the subqueries they rely on.

    Every requested clause contributes two ordering expressions: a 0/1 marker
    that is 1 when the value is missing (or, for metrics, flagged as NaN), and
    then the value itself. A descending ordering by start time and an ordering
    by run id are always appended as the natural ordering and tie-breaker.

    :param order_by_list: Iterable of order-by strings, or ``None``/empty.
    :param session: SQLAlchemy session used to build the sorting subqueries.
    :return: Tuple ``(clauses, ordering_joins)``: the ordering expressions and
             the subqueries the caller has to join for them to be usable.
    :raises MlflowException: If a clause refers to an unknown identifier type.
    """
    ordering = []
    joins_for_ordering = []

    # Unlike filters, attribute and joined-table sort keys cannot be handled
    # separately: all clauses must stay in the order the caller gave them.
    for position, raw_clause in enumerate(order_by_list or [], start=1):
        key_type, key, ascending = \
            SearchUtils.parse_order_by_for_search_runs(raw_clause)

        if SearchUtils.is_attribute(key_type, '='):
            sort_column = getattr(SqlRun, SqlRun.get_attribute_name(key))
        else:
            # '=' is only a placeholder here: any valid comparator would do.
            if SearchUtils.is_metric(key_type, '='):
                entity = SqlLatestMetric
            elif SearchUtils.is_tag(key_type, '='):
                entity = SqlTag
            elif SearchUtils.is_param(key_type, '='):
                entity = SqlParam
            else:
                raise MlflowException("Invalid identifier type '%s'" %
                                      key_type,
                                      error_code=INVALID_PARAMETER_VALUE)

            # The subquery is joined into the main request later, which makes
            # the value to sort on available to the ordering clause.
            subquery = session.query(entity).filter(entity.key == key).subquery()
            joins_for_ordering.append(subquery)
            sort_column = subquery.c.value

        # sqlite has no NULLS LAST, so presence of the value is sorted on
        # first. The subqueries are built independently but end up in one
        # query, hence a unique label per clause to avoid ambiguous names.
        if SearchUtils.is_metric(key_type, '='):
            missing_when = [(subquery.c.is_nan.is_(True), 1),
                            (sort_column.is_(None), 1)]
        else:
            # Only metrics carry an 'is_nan' column.
            missing_when = [(sort_column.is_(None), 1)]
        ordering.append(
            sql.case(missing_when, else_=0).label('clause_%s' % position))

        ordering.append(sort_column if ascending else sort_column.desc())

    ordering.append(SqlRun.start_time.desc())
    ordering.append(SqlRun.run_uuid)
    return ordering, joins_for_ordering