コード例 #1
0
ファイル: test_client.py プロジェクト: jeffmaxey/faculty
def test_experiment_client_list_runs(mocker):
    """list_runs delegates to query_runs with an AND of an OR-ed experiment-ID
    filter and a DeletedAtFilter for the DELETED lifecycle stage, forwarding
    start and limit positionally."""
    query_runs_mock = mocker.patch.object(ExperimentClient, "query_runs")

    experiment_ids_filter = CompoundFilter(
        LogicalOperator.OR,
        [
            ExperimentIdFilter(ComparisonOperator.EQUAL_TO, experiment_id)
            for experiment_id in (123, 456)
        ],
    )
    deleted_filter = DeletedAtFilter(ComparisonOperator.DEFINED, True)
    expected_filter = CompoundFilter(
        LogicalOperator.AND, [experiment_ids_filter, deleted_filter]
    )

    client = ExperimentClient(mocker.Mock(), mocker.Mock())
    response = client.list_runs(
        PROJECT_ID,
        experiment_ids=[123, 456],
        lifecycle_stage=LifecycleStage.DELETED,
        start=20,
        limit=10,
    )

    assert response == query_runs_mock.return_value
    query_runs_mock.assert_called_once_with(
        PROJECT_ID, expected_filter, None, 20, 10
    )
コード例 #2
0
def test_filter_schema_nested():
    """Compound filters nested inside a compound filter are serialised
    recursively, preserving operator and condition order at each level."""
    inner_and = CompoundFilter(
        LogicalOperator.AND, [PROJECT_ID_FILTER, TAG_FILTER]
    )
    inner_or = CompoundFilter(
        LogicalOperator.OR, [TAG_FILTER, PROJECT_ID_FILTER]
    )
    nested = CompoundFilter(LogicalOperator.AND, [inner_and, inner_or])

    expected = {
        "operator": "and",
        "conditions": [
            {
                "operator": "and",
                "conditions": [PROJECT_ID_FILTER_BODY, TAG_FILTER_BODY],
            },
            {
                "operator": "or",
                "conditions": [TAG_FILTER_BODY, PROJECT_ID_FILTER_BODY],
            },
        ],
    }

    assert _FilterSchema().dump(nested) == expected
コード例 #3
0
def test_build_search_runs_filter_no_experiment_ids(mocker):
    """With experiment_ids=None, only the view-type filter and the parsed
    filter string are AND-ed, and the experiment-ID helper is never called."""
    view_type = mocker.Mock()
    filter_string = "param.alpha > 0.2"

    experiment_id_mock = mocker.patch(
        "mlflow_faculty.filter._filter_by_experiment_id"
    )
    view_type_mock = mocker.patch(
        "mlflow_faculty.filter._filter_by_mlflow_view_type"
    )
    parse_mock = mocker.patch("mlflow_faculty.filter._parse_filter_string")

    result = build_search_runs_filter(None, filter_string, view_type)

    assert result == CompoundFilter(
        LogicalOperator.AND,
        [view_type_mock.return_value, parse_mock.return_value],
    )
    experiment_id_mock.assert_not_called()
    view_type_mock.assert_called_once_with(view_type)
    parse_mock.assert_called_once_with(filter_string)
コード例 #4
0
ファイル: test_client.py プロジェクト: jeffmaxey/faculty
def test_restore_runs(mocker):
    """restore_runs dumps an OR of RunIdFilters (one per run ID) and posts it
    to the project's restore endpoint, returning the _post result as-is."""
    post_mock = mocker.patch.object(
        ExperimentClient, "_post", return_value=mocker.Mock()
    )
    response_schema_mock = mocker.patch(
        "faculty.clients.experiment._RestoreExperimentRunsResponseSchema"
    )
    dump_mock = mocker.patch(
        "faculty.clients.experiment._FilterSchema"
    ).return_value.dump

    run_ids = [uuid4(), uuid4()]
    expected_filter = CompoundFilter(
        LogicalOperator.OR,
        [
            RunIdFilter(ComparisonOperator.EQUAL_TO, run_id)
            for run_id in run_ids
        ],
    )

    client = ExperimentClient(mocker.Mock(), mocker.Mock())
    response = client.restore_runs(PROJECT_ID, run_ids)

    assert response == post_mock.return_value
    dump_mock.assert_called_once_with(expected_filter)
    post_mock.assert_called_once_with(
        "/project/{}/run/restore/query".format(PROJECT_ID),
        response_schema_mock.return_value,
        json={"filter": dump_mock.return_value},
    )
コード例 #5
0
def test_filter_schema_compound(operator, expected_operator):
    """A flat compound filter serialises to its operator name plus its
    conditions, in order."""
    compound = CompoundFilter(operator, [PROJECT_ID_FILTER, TAG_FILTER])
    expected = {
        "operator": expected_operator,
        "conditions": [PROJECT_ID_FILTER_BODY, TAG_FILTER_BODY],
    }
    assert _FilterSchema().dump(compound) == expected
コード例 #6
0
def test_parse_filter_string_logical_operator(
    sql_operator,
    expected_operator,
    left_string,
    left_filter,
    right_string,
    right_filter,
):
    """Two conditions joined by a SQL logical operator parse to a
    CompoundFilter with the matching LogicalOperator and both sub-filters."""
    expected_filter = CompoundFilter(
        expected_operator, [left_filter, right_filter]
    )

    parsed = _parse_filter_string(
        "{} {} {}".format(left_string, sql_operator, right_string)
    )

    assert parsed == expected_filter
    # Equality alone is not enough: also pin the concrete filter type.
    assert isinstance(parsed, type(expected_filter))
コード例 #7
0
def _parse_token_list(tokens):
    """Convert a list of sqlparse Tokens into an equivalent filter.

    Whitespace tokens are discarded first. If any remaining token is an OR,
    the list is split on OR and each part is parsed recursively. Otherwise,
    if any token is an AND, the list is split on AND the same way. Checking
    OR first makes AND bind more tightly. A lone parenthesis group is
    unwrapped, and a lone comparison is parsed from its child tokens. Exactly
    three tokens are handed to ``_single_filter_from_tokens``.

    Raises:
        ValueError: if the tokens match none of the supported shapes.
    """
    significant = [token for token in tokens if not token.is_whitespace]

    # Order matters: OR must be tried before AND to get the right precedence.
    for is_separator, operator in (
        (_is_or, LogicalOperator.OR),
        (_is_and, LogicalOperator.AND),
    ):
        if any(is_separator(token) for token in significant):
            return CompoundFilter(
                operator,
                [
                    _parse_token_list(part)
                    for part in _split_list(significant, is_separator)
                ],
            )

    if len(significant) == 1:
        (token,) = significant
        if isinstance(token, SqlParenthesis):
            # Drop the opening and closing parenthesis tokens.
            return _parse_token_list(token.tokens[1:-1])
        if isinstance(token, SqlComparison):
            return _parse_token_list(token.tokens)
        raise ValueError(
            "Unsupported filter string component: {!r}".format(
                token.normalized
            )
        )

    if len(significant) == 3:
        return _single_filter_from_tokens(*significant)

    raise ValueError(
        "Unsupported filter string component: {!r}".format(
            " ".join(token.normalized for token in significant)
        )
    )
コード例 #8
0
def _filter_by_experiment_id(experiment_ids):
    """Build a filter that a run is in one of a sequence of experiment IDs.

    ``experiment_ids`` may be any iterable (list, tuple, generator, ...).
    Each ID is passed through ``int()``, so numeric strings such as ``"3"``
    are accepted.

    Returns a single ``ExperimentIdFilter`` when exactly one ID is given.
    Otherwise returns a ``CompoundFilter`` that ORs one
    ``ExperimentIdFilter`` per ID, in input order.

    Raises ``MatchesNothing`` when no IDs are given, because no filter can
    express that case. Errors raised by ``int()`` for a malformed ID
    propagate unchanged.
    """
    # Build the parts before checking for emptiness. Unsized iterables
    # (e.g. generators) then work too, where len() would raise TypeError.
    parts = [
        ExperimentIdFilter(ComparisonOperator.EQUAL_TO, int(experiment_id))
        for experiment_id in experiment_ids
    ]

    if not parts:
        # Cannot build a filter for this
        raise MatchesNothing()

    if len(parts) == 1:
        return parts[0]
    return CompoundFilter(LogicalOperator.OR, parts)
コード例 #9
0
def build_search_runs_filter(experiment_ids, filter_string, view_type):
    """Build a filter from the inputs to search_runs in the tracking store.

    Up to three parts are collected, in this order:

    - the experiment-ID filter, only when ``experiment_ids`` is not None;
    - the view-type filter, only when ``_filter_by_mlflow_view_type``
      returns something other than None;
    - the parsed filter string, only when it is not None and not blank.

    Returns None when no part was collected, the part itself when there is
    exactly one, and an AND ``CompoundFilter`` of all parts otherwise.
    """
    collected = []

    if experiment_ids is not None:
        collected.append(_filter_by_experiment_id(experiment_ids))

    view_type_part = _filter_by_mlflow_view_type(view_type)
    if view_type_part is not None:
        collected.append(view_type_part)

    has_query = filter_string is not None and filter_string.strip() != ""
    if has_query:
        collected.append(_parse_filter_string(filter_string))

    if len(collected) > 1:
        return CompoundFilter(LogicalOperator.AND, collected)
    return collected[0] if collected else None
コード例 #10
0
    mlflow_faculty.filter._filter_by_mlflow_view_type.assert_called_once_with(
        view_type
    )
    mlflow_faculty.filter._parse_filter_string.assert_not_called()


@pytest.mark.parametrize(
    "experiment_ids, expected_filter",
    [
        ([1], ExperimentIdFilter(ComparisonOperator.EQUAL_TO, 1)),
        (
            [1, 2],
            CompoundFilter(
                LogicalOperator.OR,
                [
                    ExperimentIdFilter(ComparisonOperator.EQUAL_TO, 1),
                    ExperimentIdFilter(ComparisonOperator.EQUAL_TO, 2),
                ],
            ),
        ),
        (
            ["3", "4"],
            CompoundFilter(
                LogicalOperator.OR,
                [
                    ExperimentIdFilter(ComparisonOperator.EQUAL_TO, 3),
                    ExperimentIdFilter(ComparisonOperator.EQUAL_TO, 4),
                ],
            ),
        ),
    ],