def search_registered_models(self, filter_string):
    """
    Search for registered models in backend that satisfy the filter criteria.

    :param filter_string: A filter string expression. Currently supports a single filter
                          condition on the model name, using one of the comparators in
                          ``SearchUtils.VALID_REGISTERED_MODEL_SEARCH_COMPARATORS``:
                          exact match (``name = 'model_name'``), case-sensitive partial
                          match (``name LIKE '%model%'``) or case-insensitive partial match
                          (``name ILIKE '%model%'``). An empty filter matches every model.
    :return: List of :py:class:`mlflow.entities.model_registry.RegisteredModel` objects.
    :raises MlflowException: with ``INVALID_PARAMETER_VALUE`` if the filter uses an
                             unsupported comparator, a key other than ``name``, or more
                             than one condition.
    """
    parsed_filter = SearchUtils.parse_filter_for_model_registry(filter_string)
    if len(parsed_filter) == 0:
        conditions = []
    elif len(parsed_filter) == 1:
        filter_dict = parsed_filter[0]
        comparator = filter_dict["comparator"]
        if comparator not in SearchUtils.VALID_REGISTERED_MODEL_SEARCH_COMPARATORS:
            raise MlflowException(
                "Search registered models filter expression only "
                "supports the equality(=) comparator, case-sensitive "
                "partial match (LIKE), and case-insensitive partial "
                "match (ILIKE). Input filter string: %s" % filter_string,
                error_code=INVALID_PARAMETER_VALUE)
        if filter_dict["key"] != "name":
            raise MlflowException("Invalid filter string: %s" % filter_string,
                                  error_code=INVALID_PARAMETER_VALUE)
        value = filter_dict["value"]
        if comparator == "LIKE":
            conditions = [SqlRegisteredModel.name.like(value)]
        elif comparator == "ILIKE":
            conditions = [SqlRegisteredModel.name.ilike(value)]
        else:
            conditions = [SqlRegisteredModel.name == value]
    else:
        supported_ops = "".join(
            "(" + op + ")" for op in SearchUtils.VALID_REGISTERED_MODEL_SEARCH_COMPARATORS)
        sample_query = f'name {supported_ops} "<model_name>"'
        raise MlflowException(
            f"Invalid filter string: {filter_string}. "
            f"Search registered models supports filter expressions like: {sample_query}",
            error_code=INVALID_PARAMETER_VALUE)

    with self.ManagedSessionMaker() as session:
        if self.db_type == SQLITE:
            # SQLite's LIKE is case-insensitive for ASCII by default; turn on case-sensitive
            # matching so LIKE matches the case-sensitive semantics promised above.
            session.execute("PRAGMA case_sensitive_like = true;")
        sql_registered_models = session.query(SqlRegisteredModel).filter(*conditions).all()
        registered_models = [rm.to_mlflow_entity() for rm in sql_registered_models]
        return registered_models
def search_model_versions(self, filter_string):
    """
    Search for model versions in backend that satisfy the filter criteria.

    :param filter_string: A filter string expression. Currently supports a single filter
                          condition using the equality (``=``) comparator on one of
                          ``name`` (``name = 'model_name'``), ``source_path``
                          (``source_path = '...'``) or ``run_id`` (``run_id = '...'``).
                          An empty filter matches every model version.
    :return: PagedList of :py:class:`mlflow.entities.model_registry.ModelVersion` objects.
             All matches are returned in one page; the page token is ``None``.
    :raises MlflowException: with ``INVALID_PARAMETER_VALUE`` if the filter uses a
                             comparator other than ``=``, an unsupported key, or more
                             than one condition.
    """
    parsed_filter = SearchUtils.parse_filter_for_model_registry(filter_string)
    if len(parsed_filter) == 0:
        conditions = []
    elif len(parsed_filter) == 1:
        filter_dict = parsed_filter[0]
        if filter_dict["comparator"] != "=":
            raise MlflowException(
                "Model Registry search filter only supports equality(=) "
                "comparator. Input filter string: %s" % filter_string,
                error_code=INVALID_PARAMETER_VALUE)
        # Filter key -> model-version column it is compared against.
        key_to_column = {
            "name": SqlModelVersion.name,
            "source_path": SqlModelVersion.source,
            "run_id": SqlModelVersion.run_id,
        }
        column = key_to_column.get(filter_dict["key"])
        if column is None:
            raise MlflowException("Invalid filter string: %s" % filter_string,
                                  error_code=INVALID_PARAMETER_VALUE)
        conditions = [column == filter_dict["value"]]
    else:
        raise MlflowException(
            "Model Registry expects filter to be one of "
            "\"name = '<model_name>'\" or "
            "\"source_path = '<source_path>'\" or \"run_id = '<run_id>'\". "
            "Input filter string: %s." % filter_string,
            error_code=INVALID_PARAMETER_VALUE)

    with self.ManagedSessionMaker() as session:
        # Never return versions in the STAGE_DELETED_INTERNAL stage
        # (presumably versions that were deleted but whose rows are retained).
        conditions.append(SqlModelVersion.current_stage != STAGE_DELETED_INTERNAL)
        sql_model_versions = session.query(SqlModelVersion).filter(*conditions).all()
        model_versions = [mv.to_mlflow_entity() for mv in sql_model_versions]
        return PagedList(model_versions, None)
def search_registered_models(self, filter_string, page_token=None,
                             max_results=SEARCH_REGISTERED_MODEL_MAX_RESULTS_DEFAULT):
    """
    Search for registered models in backend that satisfy the filter criteria.

    Results are ordered by model name (ascending) so that pages are stable.

    :param filter_string: A filter string expression. Currently supports a single filter
                          condition on the model name, using one of the comparators in
                          ``SearchUtils.VALID_REGISTERED_MODEL_SEARCH_COMPARATORS``:
                          exact match (``name = 'model_name'``), case-sensitive partial
                          match (``name LIKE '%model%'``) or case-insensitive partial match
                          (``name ILIKE '%model%'``). An empty filter matches every model.
    :param page_token: Token specifying the next page of results. It should be obtained from
                       a ``search_registered_models`` call.
    :param max_results: Maximum number of registered models desired. Must be at least 1 and
                        at most ``SEARCH_REGISTERED_MODEL_MAX_RESULTS_THRESHOLD``.
    :return: A PagedList of :py:class:`mlflow.entities.model_registry.RegisteredModel`
             objects that satisfy the search expressions. The pagination token for the next
             page can be obtained via the ``token`` attribute of the object.
    :raises MlflowException: with ``INVALID_PARAMETER_VALUE`` if ``max_results`` is out of
                             range or the filter string is not supported.
    """
    # A non-positive page size would yield an empty page paired with a token pointing at the
    # same offset (0 results == max_results), making callers paginate forever.
    if max_results < 1:
        raise MlflowException(
            "Invalid value for request parameter max_results. "
            "It must be a positive integer, but got value {}".format(max_results),
            INVALID_PARAMETER_VALUE)
    if max_results > SEARCH_REGISTERED_MODEL_MAX_RESULTS_THRESHOLD:
        raise MlflowException(
            "Invalid value for request parameter max_results. "
            "It must be at most {}, but got value {}".format(
                SEARCH_REGISTERED_MODEL_MAX_RESULTS_THRESHOLD, max_results),
            INVALID_PARAMETER_VALUE)

    parsed_filter = SearchUtils.parse_filter_for_model_registry(filter_string)
    offset = SearchUtils.parse_start_offset_from_page_token(page_token)

    if len(parsed_filter) == 0:
        conditions = []
    elif len(parsed_filter) == 1:
        filter_dict = parsed_filter[0]
        comparator = filter_dict["comparator"]
        if comparator not in SearchUtils.VALID_REGISTERED_MODEL_SEARCH_COMPARATORS:
            raise MlflowException(
                "Search registered models filter expression only "
                "supports the equality(=) comparator, case-sensitive "
                "partial match (LIKE), and case-insensitive partial "
                "match (ILIKE). Input filter string: %s" % filter_string,
                error_code=INVALID_PARAMETER_VALUE)
        if filter_dict["key"] != "name":
            raise MlflowException("Invalid filter string: %s" % filter_string,
                                  error_code=INVALID_PARAMETER_VALUE)
        value = filter_dict["value"]
        if comparator == "LIKE":
            conditions = [SqlRegisteredModel.name.like(value)]
        elif comparator == "ILIKE":
            conditions = [SqlRegisteredModel.name.ilike(value)]
        else:
            conditions = [SqlRegisteredModel.name == value]
    else:
        supported_ops = "".join(
            "(" + op + ")" for op in SearchUtils.VALID_REGISTERED_MODEL_SEARCH_COMPARATORS)
        sample_query = f'name {supported_ops} "<model_name>"'
        raise MlflowException(
            f"Invalid filter string: {filter_string}. "
            f"Search registered models supports filter expressions like: {sample_query}",
            error_code=INVALID_PARAMETER_VALUE)

    with self.ManagedSessionMaker() as session:
        if self.db_type == SQLITE:
            # SQLite's LIKE is case-insensitive for ASCII by default; turn on case-sensitive
            # matching so LIKE matches the case-sensitive semantics promised above.
            session.execute("PRAGMA case_sensitive_like = true;")
        query = (
            session.query(SqlRegisteredModel)
            .filter(*conditions)
            .order_by(SqlRegisteredModel.name.asc())
            .limit(max_results)
        )
        if page_token:
            query = query.offset(offset)
        sql_registered_models = query.all()
        registered_models = [rm.to_mlflow_entity() for rm in sql_registered_models]

        # A full page means more rows may follow; hand back a token for the next offset.
        next_page_token = None
        if len(registered_models) == max_results:
            next_page_token = SearchUtils.create_page_token(offset + max_results)
        return PagedList(registered_models, next_page_token)