Пример #1
0
    def search_registered_models(
        self,
        filter_string=None,
        max_results=SEARCH_REGISTERED_MODEL_MAX_RESULTS_DEFAULT,
        order_by=None,
        page_token=None,
    ):
        """
        Search for registered models in backend that satisfy the filter criteria.

        :param filter_string: Filter query string, defaults to searching all registered models.
                              At most one filter clause is accepted; its value is compared
                              against the registered model name using one of the comparators in
                              ``SearchUtils.VALID_REGISTERED_MODEL_SEARCH_COMPARATORS``.
        :param max_results: Maximum number of registered models desired. Must not exceed
                            ``SEARCH_REGISTERED_MODEL_MAX_RESULTS_THRESHOLD``.
        :param order_by: List of column names with ASC|DESC annotation, to be used for ordering
                         matching search results.
        :param page_token: Token specifying the next page of results. It should be obtained from
                            a ``search_registered_models`` call.
        :return: A PagedList of :py:class:`mlflow.entities.model_registry.RegisteredModel` objects
                that satisfy the search expressions. The pagination token for the next page can be
                obtained via the ``token`` attribute of the object.
        :raises MlflowException: With ``INVALID_PARAMETER_VALUE`` when ``max_results`` is above
                the threshold, when the filter uses an unsupported comparator, or when the
                filter contains more than one clause.
        """
        if max_results > SEARCH_REGISTERED_MODEL_MAX_RESULTS_THRESHOLD:
            raise MlflowException(
                "Invalid value for request parameter max_results. "
                "It must be at most {}, but got value {}".format(
                    SEARCH_REGISTERED_MODEL_MAX_RESULTS_THRESHOLD, max_results
                ),
                INVALID_PARAMETER_VALUE,
            )

        parsed_filter = SearchUtils.parse_filter_for_registered_models(filter_string)
        parsed_orderby = self._parse_search_registered_models_order_by(order_by)
        offset = SearchUtils.parse_start_offset_from_page_token(page_token)
        # Fetch one row more than requested: if that extra row comes back we know
        # another page exists, without issuing a second (possibly empty) query.
        max_results_for_query = max_results + 1

        def compute_next_token(current_size):
            # A full "max_results + 1" batch means there is at least one more row.
            if current_size == max_results_for_query:
                return SearchUtils.create_page_token(offset + max_results)
            return None

        if len(parsed_filter) == 0:
            conditions = []
        elif len(parsed_filter) == 1:
            filter_dict = parsed_filter[0]
            comparator = filter_dict["comparator"].upper()
            if comparator not in SearchUtils.VALID_REGISTERED_MODEL_SEARCH_COMPARATORS:
                raise MlflowException(
                    "Search registered models filter expression only "
                    "supports the equality(=) comparator, case-sensitive "
                    "partial match (LIKE), and case-insensitive partial "
                    "match (ILIKE). Input filter string: {}".format(filter_string),
                    error_code=INVALID_PARAMETER_VALUE,
                )
            # The single supported clause is always applied to the model name column.
            if comparator == SearchUtils.LIKE_OPERATOR:
                conditions = [SqlRegisteredModel.name.like(filter_dict["value"])]
            elif comparator == SearchUtils.ILIKE_OPERATOR:
                conditions = [SqlRegisteredModel.name.ilike(filter_dict["value"])]
            else:
                conditions = [SqlRegisteredModel.name == filter_dict["value"]]
        else:
            supported_ops = "".join(
                "(" + op + ")"
                for op in SearchUtils.VALID_REGISTERED_MODEL_SEARCH_COMPARATORS
            )
            sample_query = 'name {} "<model_name>"'.format(supported_ops)
            raise MlflowException(
                "Invalid filter string: {}. Search registered models supports "
                "filter expressions like: {}".format(filter_string, sample_query),
                error_code=INVALID_PARAMETER_VALUE,
            )

        with self.ManagedSessionMaker() as session:
            query = (
                session.query(SqlRegisteredModel)
                .filter(*conditions)
                .order_by(*parsed_orderby)
                .limit(max_results_for_query)
            )
            if page_token:
                query = query.offset(offset)
            sql_registered_models = query.all()
            next_page_token = compute_next_token(len(sql_registered_models))
            # Drop the look-ahead row (if any) before handing results to the caller.
            rm_entities = [
                rm.to_mlflow_entity() for rm in sql_registered_models[:max_results]
            ]
            return PagedList(rm_entities, next_page_token)
Пример #2
0
    def search_registered_models(
            self,
            filter_string=None,
            max_results=SEARCH_REGISTERED_MODEL_MAX_RESULTS_DEFAULT,
            order_by=None,
            page_token=None):
        """
        Search for registered models in backend that satisfy the filter criteria.

        Results are always ordered by model name, ascending.

        :param filter_string: Filter query string, defaults to searching all registered models.
                              Currently supports a single filter condition based on
                              the name of the model like ``name = 'model_name'``
        :param max_results: Maximum number of registered models desired. Must not exceed
                            ``SEARCH_REGISTERED_MODEL_MAX_RESULTS_THRESHOLD``.
        :param order_by: List of column names with ASC|DESC annotation, to be used for ordering
                         matching search results.
                         Note:: This field is currently not supported.
        :param page_token: Token specifying the next page of results. It should be obtained from
                            a ``search_registered_models`` call.
        :return: A PagedList of :py:class:`mlflow.entities.model_registry.RegisteredModel` objects
                that satisfy the search expressions. The pagination token for the next page can be
                obtained via the ``token`` attribute of the object.
        :raises NotImplementedError: If a non-empty ``order_by`` is supplied.
        :raises MlflowException: With ``INVALID_PARAMETER_VALUE`` when ``max_results`` is above
                the threshold, when the filter uses an unsupported comparator, or when the
                filter contains more than one clause.
        """
        if order_by:
            raise NotImplementedError(
                "Order by is not implemented for search registered models.")
        if max_results > SEARCH_REGISTERED_MODEL_MAX_RESULTS_THRESHOLD:
            raise MlflowException(
                "Invalid value for request parameter max_results. "
                "It must be at most {}, but got value {}".format(
                    SEARCH_REGISTERED_MODEL_MAX_RESULTS_THRESHOLD,
                    max_results), INVALID_PARAMETER_VALUE)

        parsed_filter = SearchUtils.parse_filter_for_registered_models(
            filter_string)
        offset = SearchUtils.parse_start_offset_from_page_token(page_token)
        # Fetch one row more than requested: if that extra row comes back we know
        # another page exists, without issuing a second (possibly empty) query.
        max_results_for_query = max_results + 1

        def compute_next_token(current_size):
            # A full "max_results + 1" batch means there is at least one more row.
            if current_size == max_results_for_query:
                return SearchUtils.create_page_token(offset + max_results)
            return None

        if len(parsed_filter) == 0:
            conditions = []
        elif len(parsed_filter) == 1:
            filter_dict = parsed_filter[0]
            comparator = filter_dict['comparator'].upper()
            if comparator not in \
                    SearchUtils.VALID_REGISTERED_MODEL_SEARCH_COMPARATORS:
                raise MlflowException(
                    'Search registered models filter expression only '
                    'supports the equality(=) comparator, case-sensitive '
                    'partial match (LIKE), and case-insensitive partial '
                    f'match (ILIKE). Input filter string: {filter_string}',
                    error_code=INVALID_PARAMETER_VALUE)
            # The single supported clause is always applied to the model name column.
            if comparator == SearchUtils.LIKE_OPERATOR:
                conditions = [
                    SqlRegisteredModel.name.like(filter_dict["value"])
                ]
            elif comparator == SearchUtils.ILIKE_OPERATOR:
                conditions = [
                    SqlRegisteredModel.name.ilike(filter_dict["value"])
                ]
            else:
                conditions = [SqlRegisteredModel.name == filter_dict["value"]]
        else:
            supported_ops = ''.join(
                '(' + op + ')'
                for op in SearchUtils.VALID_REGISTERED_MODEL_SEARCH_COMPARATORS
            )
            sample_query = f'name {supported_ops} "<model_name>"'
            raise MlflowException(
                f'Invalid filter string: {filter_string}. '
                'Search registered models supports filter expressions like: '
                f'{sample_query}',
                error_code=INVALID_PARAMETER_VALUE)
        with self.ManagedSessionMaker() as session:
            if self.db_type == SQLITE:
                # SQLite's LIKE ignores ASCII case by default; this pragma makes
                # LIKE case-sensitive for the current connection.
                # NOTE(review): newer SQLAlchemy releases reject a raw string here
                # and expect a text() construct -- confirm against the pinned version.
                session.execute("PRAGMA case_sensitive_like = true;")
            query = session\
                .query(SqlRegisteredModel)\
                .filter(*conditions)\
                .order_by(SqlRegisteredModel.name.asc())\
                .limit(max_results_for_query)
            if page_token:
                query = query.offset(offset)
            sql_registered_models = query.all()
            next_page_token = compute_next_token(len(sql_registered_models))
            # Drop the look-ahead row (if any) before handing results to the caller.
            rm_entities = [
                rm.to_mlflow_entity()
                for rm in sql_registered_models[:max_results]
            ]
            return PagedList(rm_entities, next_page_token)