Ejemplo n.º 1
0
    def search_registered_models(self, filter_string):
        """
        Search for registered models in backend that satisfy the filter criteria.

        :param filter_string: A filter string expression. Currently supports a single filter
                              condition on the model name, using one of the comparators in
                              ``SearchUtils.VALID_REGISTERED_MODEL_SEARCH_COMPARATORS``
                              (equality ``=``, case-sensitive ``LIKE`` or case-insensitive
                              ``ILIKE``), e.g. ``name = 'model_name'``. An empty filter
                              matches every registered model.

        :return: List of :py:class:`mlflow.entities.model_registry.RegisteredModel` objects.
        :raises MlflowException: With ``INVALID_PARAMETER_VALUE`` if the filter uses an
                                 unsupported comparator, filters on a key other than
                                 ``name``, or contains more than one condition.
        """
        parsed_filter = SearchUtils.parse_filter_for_model_registry(filter_string)
        if len(parsed_filter) == 0:
            conditions = []
        elif len(parsed_filter) == 1:
            filter_dict = parsed_filter[0]
            comparator = filter_dict["comparator"]
            if comparator not in SearchUtils.VALID_REGISTERED_MODEL_SEARCH_COMPARATORS:
                raise MlflowException(
                    'Search registered models filter expression only '
                    'supports the equality(=) comparator, case-sensitive '
                    'partial match (LIKE), and case-insensitive partial '
                    f'match (ILIKE). Input filter string: {filter_string}',
                    error_code=INVALID_PARAMETER_VALUE)
            if filter_dict["key"] != "name":
                raise MlflowException(f'Invalid filter string: {filter_string}',
                                      error_code=INVALID_PARAMETER_VALUE)
            value = filter_dict["value"]
            if comparator == "LIKE":
                conditions = [SqlRegisteredModel.name.like(value)]
            elif comparator == "ILIKE":
                conditions = [SqlRegisteredModel.name.ilike(value)]
            else:
                # Any remaining valid comparator is treated as exact equality.
                conditions = [SqlRegisteredModel.name == value]
        else:
            supported_ops = ''.join(
                '(' + op + ')'
                for op in SearchUtils.VALID_REGISTERED_MODEL_SEARCH_COMPARATORS)
            sample_query = f'name {supported_ops} "<model_name>"'
            raise MlflowException(
                f'Invalid filter string: {filter_string}. '
                f'Search registered models supports filter expressions like: {sample_query}',
                error_code=INVALID_PARAMETER_VALUE)
        with self.ManagedSessionMaker() as session:
            if self.db_type == SQLITE:
                # SQLite's LIKE ignores ASCII case by default; turn that off so LIKE is
                # case-sensitive as documented in the error message above.
                session.execute("PRAGMA case_sensitive_like = true;")
            sql_registered_models = session.query(SqlRegisteredModel).filter(
                *conditions).all()
            return [rm.to_mlflow_entity() for rm in sql_registered_models]
Ejemplo n.º 2
0
    def search_model_versions(self, filter_string):
        """
        Search for model versions in backend that satisfy the filter criteria.

        :param filter_string: A filter string expression. Currently supports a single equality
                              filter condition on one of ``name = 'model_name'``,
                              ``source_path = '...'`` or ``run_id = '...'``. An empty filter
                              matches every model version.

        :return: PagedList of :py:class:`mlflow.entities.model_registry.ModelVersion`
                 objects. Versions whose ``current_stage`` is ``STAGE_DELETED_INTERNAL``
                 are never returned, and the next-page token is always ``None``.
        :raises MlflowException: With ``INVALID_PARAMETER_VALUE`` if the filter uses a
                                 comparator other than ``=``, an unsupported key, or more
                                 than one condition.
        """
        parsed_filter = SearchUtils.parse_filter_for_model_registry(filter_string)
        if len(parsed_filter) == 0:
            conditions = []
        elif len(parsed_filter) == 1:
            filter_dict = parsed_filter[0]
            if filter_dict["comparator"] != "=":
                raise MlflowException(
                    'Model Registry search filter only supports equality(=) '
                    f'comparator. Input filter string: {filter_string}',
                    error_code=INVALID_PARAMETER_VALUE)
            key = filter_dict["key"]
            value = filter_dict["value"]
            if key == "name":
                conditions = [SqlModelVersion.name == value]
            elif key == "source_path":
                # The public filter key differs from the underlying column name.
                conditions = [SqlModelVersion.source == value]
            elif key == "run_id":
                conditions = [SqlModelVersion.run_id == value]
            else:
                raise MlflowException(f'Invalid filter string: {filter_string}',
                                      error_code=INVALID_PARAMETER_VALUE)
        else:
            raise MlflowException(
                'Model Registry expects filter to be one of '
                '"name = \'<model_name>\'" or '
                '"source_path = \'<source_path>\'" or "run_id = \'<run_id>\'". '
                f'Input filter string: {filter_string}.',
                error_code=INVALID_PARAMETER_VALUE)

        # Always hide versions in the internal deleted stage (presumably soft-deleted
        # rows that remain in the table).
        conditions.append(SqlModelVersion.current_stage != STAGE_DELETED_INTERNAL)
        with self.ManagedSessionMaker() as session:
            sql_model_versions = session.query(SqlModelVersion).filter(*conditions).all()
            model_versions = [mv.to_mlflow_entity() for mv in sql_model_versions]
            return PagedList(model_versions, None)
Ejemplo n.º 3
0
    def search_registered_models(self,
                                 filter_string,
                                 page_token=None,
                                 max_results=SEARCH_REGISTERED_MODEL_MAX_RESULTS_DEFAULT):
        """
        Search for registered models in backend that satisfy the filter criteria.

        Results are ordered by model name ascending and returned one page at a time.

        :param filter_string: A filter string expression. Currently supports a single filter
                              condition on the model name, using one of the comparators in
                              ``SearchUtils.VALID_REGISTERED_MODEL_SEARCH_COMPARATORS``
                              (equality ``=``, case-sensitive ``LIKE`` or case-insensitive
                              ``ILIKE``), e.g. ``name = 'model_name'``.
        :param page_token: Token specifying the next page of results. It should be obtained from
                            a ``search_registered_models`` call.
        :param max_results: Maximum number of registered models desired. Must be between 1 and
                            ``SEARCH_REGISTERED_MODEL_MAX_RESULTS_THRESHOLD`` inclusive.
        :return: A PagedList of :py:class:`mlflow.entities.model_registry.RegisteredModel` objects
                that satisfy the search expressions. The pagination token for the next page can be
                obtained via the ``token`` attribute of the object.
        :raises MlflowException: With ``INVALID_PARAMETER_VALUE`` if ``max_results`` is out of
                                 range, or if the filter uses an unsupported comparator or key,
                                 or contains more than one condition.
        """
        if max_results < 1:
            # A non-positive page size would make compute_next_token() below emit a token
            # for the same offset (0 == 0), so a client paging through results never ends.
            raise MlflowException("Invalid value for request parameter max_results. "
                                  "It must be a positive integer, but got value {}"
                                  .format(max_results),
                                  INVALID_PARAMETER_VALUE)
        if max_results > SEARCH_REGISTERED_MODEL_MAX_RESULTS_THRESHOLD:
            raise MlflowException("Invalid value for request parameter max_results. "
                                  "It must be at most {}, but got value {}"
                                  .format(SEARCH_REGISTERED_MODEL_MAX_RESULTS_THRESHOLD,
                                          max_results),
                                  INVALID_PARAMETER_VALUE)

        parsed_filter = SearchUtils.parse_filter_for_model_registry(filter_string)
        offset = SearchUtils.parse_start_offset_from_page_token(page_token)

        def compute_next_token(current_size):
            # A full page suggests more rows may follow; an under-full page is the last one.
            next_token = None
            if max_results == current_size:
                final_offset = offset + max_results
                next_token = SearchUtils.create_page_token(final_offset)
            return next_token

        if len(parsed_filter) == 0:
            conditions = []
        elif len(parsed_filter) == 1:
            filter_dict = parsed_filter[0]
            comparator = filter_dict["comparator"]
            if comparator not in SearchUtils.VALID_REGISTERED_MODEL_SEARCH_COMPARATORS:
                raise MlflowException('Search registered models filter expression only '
                                      'supports the equality(=) comparator, case-sensitive '
                                      'partial match (LIKE), and case-insensitive partial '
                                      f'match (ILIKE). Input filter string: {filter_string}',
                                      error_code=INVALID_PARAMETER_VALUE)
            if filter_dict["key"] != "name":
                raise MlflowException(f'Invalid filter string: {filter_string}',
                                      error_code=INVALID_PARAMETER_VALUE)
            value = filter_dict["value"]
            if comparator == "LIKE":
                conditions = [SqlRegisteredModel.name.like(value)]
            elif comparator == "ILIKE":
                conditions = [SqlRegisteredModel.name.ilike(value)]
            else:
                # Any remaining valid comparator is treated as exact equality.
                conditions = [SqlRegisteredModel.name == value]
        else:
            supported_ops = ''.join('(' + op + ')' for op in
                                    SearchUtils.VALID_REGISTERED_MODEL_SEARCH_COMPARATORS)
            sample_query = f'name {supported_ops} "<model_name>"'
            raise MlflowException(f'Invalid filter string: {filter_string}. '
                                  'Search registered models supports filter expressions '
                                  f'like: {sample_query}',
                                  error_code=INVALID_PARAMETER_VALUE)
        with self.ManagedSessionMaker() as session:
            if self.db_type == SQLITE:
                # SQLite's LIKE ignores ASCII case by default; turn that off so LIKE is
                # case-sensitive as documented in the error message above.
                session.execute("PRAGMA case_sensitive_like = true;")
            # A stable ordering is required for offset-based pagination to be consistent.
            query = session\
                .query(SqlRegisteredModel)\
                .filter(*conditions)\
                .order_by(SqlRegisteredModel.name.asc())\
                .limit(max_results)
            if page_token:
                query = query.offset(offset)
            sql_registered_models = query.all()
            registered_models = [rm.to_mlflow_entity() for rm in sql_registered_models]
            next_page_token = compute_next_token(len(registered_models))
            return PagedList(registered_models, next_page_token)