def test_databricks_rest_store_get_experiment_by_name(self):
        """Check that an INTERNAL_ERROR from get-by-name is propagated as-is.

        The Databricks REST client must not fall back to ListExperiments for
        500-level errors that are not ENDPOINT_NOT_FOUND, so exactly one HTTP
        request (the get-by-name call) is expected.
        """
        host_creds = MlflowHostCreds("https://hello")
        rest_store = DatabricksRestStore(lambda: host_creds)

        def raise_internal_error(*_args, **_kwargs):
            # Stand-in for the HTTP layer: every request fails with INTERNAL_ERROR.
            raise MlflowException("Some internal error!", INTERNAL_ERROR)

        with mock.patch("mlflow.utils.rest_utils.http_request") as http_mock:
            http_mock.side_effect = raise_internal_error

            with pytest.raises(MlflowException) as raised:
                rest_store.get_experiment_by_name("abc")

            error = raised.value
            assert error.error_code == ErrorCode.Name(INTERNAL_ERROR)
            assert error.message == "Some internal error!"

            # The only request issued must be the get-by-name lookup itself.
            expected_request = GetExperimentByName(experiment_name="abc")
            self._verify_requests(
                http_mock,
                host_creds,
                "experiments/get-by-name",
                "GET",
                message_to_json(expected_request),
            )
            assert http_mock.call_count == 1
Exemplo n.º 2
0
    def test_databricks_paginate_list_experiments(self):
        """Check each page yielded by ``_paginate_list_experiments`` against canned responses.

        Three HTTP responses are queued with next-page tokens "a", "b" and
        None; every yielded page must expose the experiment name and token of
        the response at the same position.
        """
        host_creds = MlflowHostCreds("https://hello")
        rest_store = DatabricksRestStore(lambda: host_creds)
        page_tokens = ["a", "b", None]

        def build_page_response(token):
            # Canned HTTP 200 response holding one experiment named after the token.
            exp = Experiment(
                experiment_id="123",
                name=str(token),
                artifact_location="/abc",
                lifecycle_stage=LifecycleStage.ACTIVE,
            )
            payload = {
                "experiments": [json.loads(message_to_json(exp.to_proto()))],
                "next_page_token": token,
            }
            response = mock.MagicMock()
            response.text = json.dumps(payload)
            response.status_code = 200
            return response

        canned_responses = [build_page_response(token) for token in page_tokens]

        with mock.patch(
            "mlflow.utils.rest_utils.http_request", side_effect=canned_responses
        ):
            pages = rest_store._paginate_list_experiments(ViewType.ACTIVE_ONLY)
            for page_index, page in enumerate(pages):
                assert page[0].name == str(page_tokens[page_index])
                assert page.token == page_tokens[page_index]
Exemplo n.º 3
0
def _get_databricks_rest_store(store_uri, **_):
    """Return a ``DatabricksRestStore`` whose credentials are looked up from ``store_uri``.

    The store receives a zero-argument callable, so
    ``get_databricks_host_creds(store_uri)`` runs each time that callable is
    invoked rather than once here. Extra keyword arguments are accepted and
    ignored.
    """

    def _resolve_host_creds():
        return get_databricks_host_creds(store_uri)

    return DatabricksRestStore(_resolve_host_creds)
Exemplo n.º 4
0
def _get_databricks_rest_store(store_uri, **_):
    """Return a ``DatabricksRestStore`` bound to the profile derived from ``store_uri``.

    The profile is extracted from the URI once, up front; the store receives a
    zero-argument callable that calls ``get_databricks_host_creds(profile)``
    whenever it is invoked. Extra keyword arguments are accepted and ignored.
    """
    db_profile = get_db_profile_from_uri(store_uri)

    def _resolve_host_creds():
        return get_databricks_host_creds(db_profile)

    return DatabricksRestStore(_resolve_host_creds)
Exemplo n.º 5
0
def _get_databricks_rest_store(store_uri, **_):
    """Return a ``DatabricksRestStore`` whose credentials are looked up from ``store_uri``.

    ``store_uri`` is pre-bound with ``partial`` so that
    ``get_databricks_host_creds`` is only called when the store invokes the
    provider. Extra keyword arguments are accepted and ignored.
    """
    host_creds_provider = partial(get_databricks_host_creds, store_uri)
    return DatabricksRestStore(host_creds_provider)