コード例 #1
0
ファイル: test_search_utils.py プロジェクト: zxf1864/mlflow
def test_anded_expression():
    """A single metric expression is parsed into one typed comparison dict."""
    accuracy_clause = DoubleClause(comparator=">=", value=.94)
    expression = SearchExpression(
        metric=MetricSearchExpression(key="accuracy", double=accuracy_clause))
    search_filter = SearchFilter(anded_expressions=[expression])
    expected = [dict(type="metric", key="accuracy", comparator=">=", value=0.94)]
    assert search_filter._parse() == expected
コード例 #2
0
def test_search_filter_basics():
    """SearchFilter accepts anded expressions or a filter string, but not both at once."""
    search_filter = "This is a filter string"
    anded_expressions = [SearchExpression(), SearchExpression()]

    # only anded_expressions
    SearchFilter(anded_expressions=anded_expressions)

    # only search filter
    SearchFilter(filter_string=search_filter)

    # both: construction must fail. The message check has to live OUTSIDE the
    # `with` block, because statements after the raising call inside it never run.
    with pytest.raises(MlflowException) as e:
        SearchFilter(anded_expressions=anded_expressions, filter_string=search_filter)
    assert "Can specify only one of 'filter' or 'search_expression'" in e.value.message
コード例 #3
0
def test_bad_comparators(entity_type, bad_comparators, key, entity_value):
    """Filtering a run with any comparator from `bad_comparators` must be rejected."""
    run_info = RunInfo(
        run_uuid="hi", run_id="hi", experiment_id=0,
        user_id="user-id", status=RunStatus.to_string(RunStatus.FAILED),
        start_time=0, end_time=1, lifecycle_stage=LifecycleStage.ACTIVE)
    run_data = RunData(metrics=[], params=[], tags=[])
    run = Run(run_info=run_info, run_data=run_data)

    template = "{entity_type}.{key} {comparator} {value}"
    for comparator in bad_comparators:
        filter_string = template.format(entity_type=entity_type, key=key,
                                        comparator=comparator, value=entity_value)
        search_filter = SearchFilter(filter_string=filter_string)
        with pytest.raises(MlflowException) as exc_info:
            search_filter.filter(run)
        assert "Invalid comparator" in str(exc_info.value.message)
コード例 #4
0
def test_anded_expression_2():
    """Metric and parameter expressions are all parsed, preserving their input order."""
    metric_expressions = [
        MetricSearchExpression(key="accuracy",
                               double=DoubleClause(comparator=">=", value=.94)),
        MetricSearchExpression(key="error",
                               double=DoubleClause(comparator="<", value=.01)),
        MetricSearchExpression(key="mse",
                               float=FloatClause(comparator=">=", value=5)),
    ]
    parameter_expressions = [
        ParameterSearchExpression(key="a",
                                  string=StringClause(comparator="=", value="0")),
        ParameterSearchExpression(key="b",
                                  string=StringClause(comparator="!=", value="blah")),
    ]
    anded_expressions = [SearchExpression(metric=m) for m in metric_expressions]
    anded_expressions += [SearchExpression(parameter=p) for p in parameter_expressions]
    sf = SearchFilter(SearchRuns(anded_expressions=anded_expressions))

    expected = [
        {"type": "metric", "key": "accuracy", "comparator": ">=", "value": 0.94},
        {"type": "metric", "key": "error", "comparator": "<", "value": 0.01},
        {"type": "metric", "key": "mse", "comparator": ">=", "value": 5},
        {"type": "parameter", "key": "a", "comparator": "=", "value": "0"},
        {"type": "parameter", "key": "b", "comparator": "!=", "value": "blah"},
    ]
    assert sf._parse() == expected
コード例 #5
0
 def _search(self, experiment_id, metrics_expressions=None, param_expressions=None,
             run_view_type=ViewType.ALL):
     """Return run UUIDs in `experiment_id` matching all metric and param expressions."""
     search_runs = SearchRuns()
     # Metric expressions go first, then parameter expressions; a missing list counts as empty.
     for expressions in (metrics_expressions, param_expressions):
         search_runs.anded_expressions.extend(expressions or [])
     search_filter = SearchFilter(search_runs)
     matching_runs = self.store.search_runs([experiment_id], search_filter, run_view_type)
     return [run.info.run_uuid for run in matching_runs]
コード例 #6
0
 def _search(self, fs, experiment_id, filter_str=None,
             run_view_type=ViewType.ALL, max_results=SEARCH_MAX_RESULTS_DEFAULT):
     """Return run IDs in `experiment_id` from store `fs`, optionally filtered by `filter_str`."""
     # An empty or missing filter string means "no filter" (None is passed to the store).
     search_filter = None
     if filter_str:
         search_filter = SearchFilter(filter_string=filter_str)
     found = fs.search_runs([experiment_id], search_filter, run_view_type, max_results)
     return [run.info.run_id for run in found]
コード例 #7
0
ファイル: client.py プロジェクト: zxf1864/mlflow
    def search_runs(self, experiment_ids, filter_string, run_view_type=ViewType.ACTIVE_ONLY):
        """
        Search runs that fit the search criteria.

        The filter string is wrapped in a ``SearchFilter`` and the query is
        delegated to this client's tracking store.

        :param experiment_ids: List of experiment IDs
        :param filter_string: Filter query string.
        :param run_view_type: one of enum values ACTIVE_ONLY, DELETED_ONLY, or ALL runs
                              defined in :py:class:`mlflow.entities.ViewType`.
        :return: A list of :py:class:`mlflow.entities.Run` objects, as returned by the
            underlying store, that satisfy the filter.
        """
        return self.store.search_runs(experiment_ids=experiment_ids,
                                      search_filter=SearchFilter(filter_string=filter_string),
                                      run_view_type=run_view_type)
コード例 #8
0
ファイル: handlers.py プロジェクト: tpartyka/mlflow
def _search_runs():
    """Serve a SearchRuns request: query the store and return matching runs as JSON."""
    request_message = _get_request_message(SearchRuns())
    response_message = SearchRuns.Response()
    # Fall back to active runs only when the client did not set a view type.
    if request_message.HasField('run_view_type'):
        run_view_type = ViewType.from_proto(request_message.run_view_type)
    else:
        run_view_type = ViewType.ACTIVE_ONLY
    store = _get_store()
    search_filter = SearchFilter(request_message)
    run_entities = store.search_runs(request_message.experiment_ids, search_filter,
                                     run_view_type)
    response_message.runs.extend([run.to_proto() for run in run_entities])
    response = Response(mimetype='application/json')
    response.set_data(message_to_json(response_message))
    return response
コード例 #9
0
ファイル: test_search_utils.py プロジェクト: zxf1864/mlflow
def test_bad_comparators(entity_type, bad_comparators, entity_value):
    """A filter on `<entity_type>.abc` using any of `bad_comparators` must be rejected."""
    run_info = RunInfo(run_uuid="hi",
                       experiment_id=0,
                       name="name",
                       source_type=SourceType.PROJECT,
                       source_name="source-name",
                       entry_point_name="entry-point-name",
                       user_id="user-id",
                       status=RunStatus.FAILED,
                       start_time=0,
                       end_time=1,
                       source_version="version",
                       lifecycle_stage=LifecycleStage.ACTIVE)
    run = Run(run_info=run_info, run_data=RunData(metrics=[], params=[], tags=[]))

    for comparator in bad_comparators:
        filter_string = "{entity_type}.abc {comparator} {value}".format(
            entity_type=entity_type, comparator=comparator, value=entity_value)
        search_filter = SearchFilter(filter_string=filter_string)
        with pytest.raises(MlflowException) as exc_info:
            search_filter.filter(run)
        assert "Invalid comparator" in str(exc_info.value.message)
コード例 #10
0
ファイル: handlers.py プロジェクト: yuecong/mlflow
def _search_runs():
    """Serve a SearchRuns request (expressions or filter string) and return runs as JSON."""
    request_message = _get_request_message(SearchRuns())
    response_message = SearchRuns.Response()
    run_view_type = (ViewType.from_proto(request_message.run_view_type)
                     if request_message.HasField('run_view_type')
                     else ViewType.ACTIVE_ONLY)
    search_filter = SearchFilter(anded_expressions=request_message.anded_expressions,
                                 filter_string=request_message.filter)
    run_entities = _get_store().search_runs(request_message.experiment_ids,
                                            search_filter,
                                            run_view_type,
                                            request_message.max_results)
    response_message.runs.extend([run.to_proto() for run in run_entities])
    response = Response(mimetype='application/json')
    response.set_data(message_to_json(response_message))
    return response
コード例 #11
0
    def search_runs(self, experiment_ids, filter_string,
                    run_view_type=ViewType.ACTIVE_ONLY,
                    max_results=SEARCH_MAX_RESULTS_DEFAULT):
        """
        Search runs that fit the search criteria.

        The filter string is wrapped in a ``SearchFilter`` and the query is
        delegated to this client's tracking store.

        :param experiment_ids: List of experiment IDs
        :param filter_string: Filter query string.
        :param run_view_type: one of enum values ACTIVE_ONLY, DELETED_ONLY, or ALL runs
                              defined in :py:class:`mlflow.entities.ViewType`.
        :param max_results: Maximum number of runs desired.

        :return: A list of :py:class:`mlflow.entities.Run` objects that satisfy the search
            expressions
        """
        return self.store.search_runs(experiment_ids=experiment_ids,
                                      search_filter=SearchFilter(filter_string=filter_string),
                                      run_view_type=run_view_type,
                                      max_results=max_results)
コード例 #12
0
ファイル: test_search_utils.py プロジェクト: zxf1864/mlflow
def test_invalid_clauses(filter_string, error_message):
    """Parsing a malformed filter string raises an MlflowException naming the problem."""
    # Construction stays inside the context manager so an early failure is also caught.
    with pytest.raises(MlflowException) as exc_info:
        parser = SearchFilter(filter_string=filter_string)
        parser._parse()
    assert error_message in exc_info.value.message
コード例 #13
0
def test_filter(filter_string, parsed_filter):
    """A filter wrapped in a SearchRuns message parses to the expected structure."""
    request = SearchRuns(filter=filter_string)
    parsed = SearchFilter(request)._parse()
    assert parsed == parsed_filter
コード例 #14
0
ファイル: test_search_utils.py プロジェクト: zxf1864/mlflow
def test_correct_quote_trimming(filter_string, parsed_filter):
    """Quoted keys and values in a filter string are parsed with their quotes handled."""
    parsed = SearchFilter(filter_string=filter_string)._parse()
    assert parsed == parsed_filter
コード例 #15
0
ファイル: test_search_utils.py プロジェクト: zxf1864/mlflow
def test_filter(filter_string, parsed_filter):
    """A raw filter string parses to the expected list of clause dicts."""
    search_filter = SearchFilter(filter_string=filter_string)
    assert search_filter._parse() == parsed_filter
コード例 #16
0
def test_error_filter(filter_string, error_message):
    """An invalid filter inside a SearchRuns message fails to parse with a clear error."""
    with pytest.raises(MlflowException) as exc_info:
        request = SearchRuns(filter=filter_string)
        SearchFilter(request)._parse()
    assert error_message in exc_info.value.message
コード例 #17
0
 def _search_with_filter_string(self, fs, experiment_id, filter_str, run_view_type=ViewType.ALL):
     """Return run UUIDs in `experiment_id` from store `fs` that match `filter_str`."""
     search_filter = SearchFilter(filter_string=filter_str)
     matching_runs = fs.search_runs([experiment_id], search_filter, run_view_type)
     return [run.info.run_uuid for run in matching_runs]