Exemplo n.º 1
0
    def test_databricks_paginate_list_experiments(self):
        """Paginated ListExperiments must yield one page per mocked response, in order.

        Three mocked HTTP responses carry next-page tokens "a", "b" and None. Each page
        holds a single experiment whose name is the string form of that page's token, so
        the test can match every yielded page to the response it came from.
        """
        creds = MlflowHostCreds("https://hello")
        store = DatabricksRestStore(lambda: creds)

        next_page_tokens = ["a", "b", None]
        list_exp_responses = []
        for next_page_token in next_page_tokens:
            experiment = Experiment(
                experiment_id="123",
                name=str(next_page_token),
                artifact_location="/abc",
                lifecycle_stage=LifecycleStage.ACTIVE,
            )
            list_exp_response = mock.MagicMock()
            list_exp_response.text = json.dumps({
                "experiments": [json.loads(message_to_json(experiment.to_proto()))],
                "next_page_token": next_page_token,
            })
            list_exp_response.status_code = 200
            list_exp_responses.append(list_exp_response)

        with mock.patch("mlflow.utils.rest_utils.http_request",
                        side_effect=list_exp_responses):
            # Materialize every page. Iterating the generator directly in a for-loop
            # would silently pass if pagination stopped early.
            pages = list(store._paginate_list_experiments(ViewType.ACTIVE_ONLY))

        assert len(pages) == len(next_page_tokens)
        for experiments, next_page_token in zip(pages, next_page_tokens):
            assert experiments[0].name == str(next_page_token)
            assert experiments.token == next_page_token
Exemplo n.º 2
0
    def test_creation_and_hydration(self):
        """An Experiment built directly, from its proto, or from a dict must agree.

        The directly-constructed experiment must also convert to the expected dict.
        """
        exp_id = random_int()
        name = "exp_%d_%d" % (random_int(), random_int())
        location = random_file(".json")
        expected = (exp_id, name, location)

        direct = Experiment(*expected)
        self._check(direct, *expected)

        fields = dict(zip(("experiment_id", "name", "artifact_location"), expected))
        self.assertEqual(dict(direct), fields)

        # Each rebuild is deferred so it runs right before its own check, as before.
        rebuilders = (
            lambda: Experiment.from_proto(direct.to_proto()),
            lambda: Experiment.from_dictionary(fields),
        )
        for rebuild in rebuilders:
            self._check(rebuild(), *expected)
Exemplo n.º 3
0
    def test_creation_and_hydration(self):
        """Round-trip an active Experiment through its proto and dict representations.

        Every reconstruction must carry the same id, name, artifact location and
        lifecycle stage as the original.
        """
        exp_id = str(random_int())
        name = "exp_%d_%d" % (random_int(), random_int())
        lifecycle_stage = LifecycleStage.ACTIVE
        location = random_file(".json")
        attrs = (exp_id, name, location, lifecycle_stage)

        source_exp = Experiment(*attrs)
        self._check(source_exp, *attrs)

        dict_keys = ("experiment_id", "name", "artifact_location", "lifecycle_stage")
        as_dict = dict(zip(dict_keys, attrs))
        self.assertEqual(dict(source_exp), as_dict)

        proto_copy = Experiment.from_proto(source_exp.to_proto())
        self._check(proto_copy, *attrs)

        dict_copy = Experiment.from_dictionary(as_dict)
        self._check(dict_copy, *attrs)
Exemplo n.º 4
0
    def test_get_experiment_by_name(self, store_class):
        """Check ``get_experiment_by_name`` of a REST store against a mocked HTTP layer.

        Scenarios, run in sequence against one patched ``http_request``:
          1. GetExperimentByName succeeds and the returned experiment matches.
          2. The server answers 404 / RESOURCE_DOES_NOT_EXIST; the lookup returns None
             after exactly one request.
          3. The server raises ENDPOINT_NOT_FOUND for GetExperimentByName (an older
             server); the client issues a ListExperiments(view_type=ALL) request and
             still finds the experiment.
          4. The server raises REQUEST_LIMIT_EXCEEDED; the client re-raises it without
             falling back to ListExperiments (exactly one request).

        Args:
            store_class: REST store class under test; presumably injected by a pytest
                parametrization that is not visible here.
        """
        creds = MlflowHostCreds('https://hello')
        store = store_class(lambda: creds)

        def make_response(status_code, text):
            # Configure a fresh MagicMock *instance*. Assigning attributes on the
            # ``mock.MagicMock`` class itself (i.e. omitting the parentheses) would make
            # every "response" here the same shared object and leak ``status_code`` /
            # ``text`` into every MagicMock created later in the test session.
            response = mock.MagicMock()
            response.status_code = status_code
            response.text = text
            return response

        def assert_matches(result, expected):
            # Compare the fields the store is expected to round-trip.
            assert result.experiment_id == expected.experiment_id
            assert result.name == expected.name
            assert result.artifact_location == expected.artifact_location
            assert result.lifecycle_stage == expected.lifecycle_stage

        with mock.patch('mlflow.store.rest_store.http_request') as mock_http:
            experiment = Experiment(
                experiment_id="123", name="abc", artifact_location="/abc",
                lifecycle_stage=LifecycleStage.ACTIVE)
            experiment_json = json.loads(message_to_json(experiment.to_proto()))

            # 1. Successful GetExperimentByName.
            mock_http.return_value = make_response(
                200, json.dumps({"experiment": experiment_json}))
            result = store.get_experiment_by_name("abc")
            expected_message0 = GetExperimentByName(experiment_name="abc")
            self._verify_requests(mock_http, creds,
                                  "experiments/get-by-name", "GET",
                                  message_to_json(expected_message0))
            assert_matches(result, experiment)

            # 2. GetExperimentByName against a nonexistent experiment.
            mock_http.reset_mock()
            not_found_text = MlflowException(
                "Exp doesn't exist!", RESOURCE_DOES_NOT_EXIST).serialize_as_json()
            mock_http.return_value = make_response(404, not_found_text)
            assert store.get_experiment_by_name("nonexistent-experiment") is None
            expected_message1 = GetExperimentByName(experiment_name="nonexistent-experiment")
            self._verify_requests(mock_http, creds,
                                  "experiments/get-by-name", "GET",
                                  message_to_json(expected_message1))
            assert mock_http.call_count == 1

            # 3. REST client behavior against a mocked old server, which has a handler
            # for ListExperiments but not for GetExperimentByName.
            mock_http.reset_mock()
            list_exp_response = make_response(
                200, json.dumps({"experiments": [experiment_json]}))

            def response_fn(*args, **kwargs):
                # pylint: disable=unused-argument
                if kwargs.get('endpoint') == "/api/2.0/mlflow/experiments/get-by-name":
                    raise MlflowException("GetExperimentByName is not implemented",
                                          ENDPOINT_NOT_FOUND)
                return list_exp_response

            mock_http.side_effect = response_fn
            result = store.get_experiment_by_name("abc")
            expected_message2 = ListExperiments(view_type=ViewType.ALL)
            self._verify_requests(mock_http, creds,
                                  "experiments/get-by-name", "GET",
                                  message_to_json(expected_message0))
            self._verify_requests(mock_http, creds,
                                  "experiments/list", "GET",
                                  message_to_json(expected_message2))
            assert_matches(result, experiment)

            # 4. The REST client must not fall back to ListExperiments on 429 errors
            # (hitting rate limits).
            mock_http.reset_mock()

            def rate_limit_response_fn(*args, **kwargs):
                # pylint: disable=unused-argument
                raise MlflowException("Hit rate limit on GetExperimentByName",
                                      REQUEST_LIMIT_EXCEEDED)

            mock_http.side_effect = rate_limit_response_fn
            with pytest.raises(MlflowException) as exc_info:
                store.get_experiment_by_name("imspamming")
            assert exc_info.value.error_code == ErrorCode.Name(REQUEST_LIMIT_EXCEEDED)
            assert mock_http.call_count == 1