def test_get_paginated_runs_lt_maxresults_multipage():
    """
    Fewer runs exist than ``max_results``, yet more than one page is needed
    to fetch them all.

    The mocked client serves two full pages of 10 runs, each carrying a
    continuation token, followed by a final 1-run page with an empty token.
    All 21 runs are therefore expected back.
    """
    page_size = 10
    full_page = PagedList([create_run() for _ in range(page_size)], "token")
    last_page = PagedList([create_run()], "")
    expected_total = 21
    with mock.patch("mlflow.tracking.fluent.NUM_RUNS_PER_PAGE_PANDAS", page_size):
        with mock.patch.object(MlflowClient, "search_runs") as search_runs_mock:
            search_runs_mock.side_effect = [full_page, full_page, last_page]
            result = _get_paginated_runs([], "", ViewType.ACTIVE_ONLY, 50, None)
            assert len(result) == expected_total
def test_get_paginated_runs_eq_maxresults_token():
    """
    Runs returned are equal to max_results, which is exactly a full number of pages.

    The server may or may not send a token back, depending on whether it knows
    that more runs exist. In this example a token IS sent back. The expected
    behavior is to NOT query for more pages.
    """
    runs = [create_run() for _ in range(10)]
    tokenized_runs = PagedList(runs, "abc")
    # Queued as a possible second page so that an unwanted extra query returns
    # a value instead of exhausting side_effect; assert_called_once() below is
    # what flags such an extra query.
    blank_runs = PagedList([], "")
    max_results = 10
    with mock.patch("mlflow.tracking.fluent.NUM_RUNS_PER_PAGE_PANDAS", 10):
        with mock.patch.object(MlflowClient, "search_runs"):
            MlflowClient.search_runs.side_effect = [tokenized_runs, blank_runs]
            paginated_runs = _get_paginated_runs([], "", ViewType.ACTIVE_ONLY, max_results, None)
            MlflowClient.search_runs.assert_called_once()
            assert len(paginated_runs) == 10
def test_search_runs_default_view_type(mock_get_request_message, mock_store):
    """
    A SearchRuns request that does not set a run view type results in
    ``ViewType.ACTIVE_ONLY`` being passed as the third positional argument
    to the store's ``search_runs``.
    """
    request = SearchRuns(experiment_ids=["0"])
    mock_get_request_message.return_value = request
    mock_store.search_runs.return_value = PagedList([], None)

    _search_runs()

    positional_args = mock_store.search_runs.call_args[0]
    assert positional_args[2] == ViewType.ACTIVE_ONLY
def test_get_paginated_runs_gt_maxresults_multipage():
    """
    More runs match the search criteria than ``max_results``, so several pages
    are fetched and only ``max_results`` runs come back.

    With a page size of 8 and max_results of 20, two full pages are requested.
    The final request asks only for the remaining ``20 % 8`` runs and forwards
    the token from the previous page.
    """
    page_size = 8
    max_results = 20
    full_page = PagedList([create_run() for _ in range(page_size)], "abc")
    final_page = PagedList([create_run() for _ in range(4)], "def")
    # The exact sequence of client queries that pagination should issue.
    expected_calls = [
        mock.call([12], "", ViewType.ACTIVE_ONLY, page_size, None, None),
        mock.call([12], "", ViewType.ACTIVE_ONLY, page_size, None, "abc"),
        mock.call([12], "", ViewType.ACTIVE_ONLY, max_results % page_size, None, "abc"),
    ]
    with mock.patch("mlflow.tracking.fluent.NUM_RUNS_PER_PAGE_PANDAS", page_size):
        with mock.patch.object(MlflowClient, "search_runs") as search_runs_mock:
            search_runs_mock.side_effect = [full_page, full_page, final_page]
            fetched = _get_paginated_runs([12], "", ViewType.ACTIVE_ONLY, max_results, None)
            search_runs_mock.assert_has_calls(expected_calls)
            assert len(fetched) == max_results
def test_get_paginated_runs_lt_maxresults_onepage_nonetoken():
    """
    Number of runs is less than max_results and fits on one page.

    The token passed back on the last page is None, not the empty string.
    """
    runs = [create_run() for _ in range(5)]
    tokenized_runs = PagedList(runs, None)
    max_results = 50
    with mock.patch("mlflow.tracking.fluent.NUM_RUNS_PER_PAGE_PANDAS", 10):
        with mock.patch.object(MlflowClient, "search_runs", return_value=tokenized_runs):
            paginated_runs = _get_paginated_runs([], "", ViewType.ACTIVE_ONLY, max_results, None)
            # A None token on the returned page should stop pagination after one query.
            MlflowClient.search_runs.assert_called_once()
            assert len(paginated_runs) == 5
def test_get_paginated_runs_gt_maxresults_onepage():
    """
    Number of runs that fit the search criteria is greater than max_results.

    Only one page is expected, because max_results (10) is smaller than the
    page size (20). Exactly max_results runs should be returned.
    """
    runs = [create_run() for _ in range(10)]
    tokenized_runs = PagedList(runs, "abc")
    max_results = 10
    with mock.patch("mlflow.tracking.fluent.NUM_RUNS_PER_PAGE_PANDAS", 20):
        with mock.patch.object(MlflowClient, "search_runs", return_value=tokenized_runs):
            paginated_runs = _get_paginated_runs([123], "", ViewType.ACTIVE_ONLY, max_results, None)
            # The single request asks for max_results (not the page size) and
            # carries no page token.
            MlflowClient.search_runs.assert_called_once_with(
                [123], "", ViewType.ACTIVE_ONLY, max_results, None, None
            )
            assert len(paginated_runs) == 10