Example #1
0
def test_parse_legacy_experiment():
    """
    Parse a legacy experiment JSON payload (integer ``experiment_id`` plus an
    extra unknown key) into a proto, and check the resulting Experiment entity.
    """
    legacy_payload = {"experiment_id": 123, "name": "name", "unknown": "field"}
    proto = ProtoExperiment()
    parse_dict(legacy_payload, proto)
    parsed = Experiment.from_proto(proto)

    # The legacy integer id must surface as the string "123" on the entity.
    assert parsed.experiment_id == "123"
    assert parsed.name == "name"
    # artifact_location was absent from the payload, so it comes back empty.
    assert parsed.artifact_location == ""
Example #2
0
def _read_persisted_experiment_dict(experiment_dict):
    """
    Build an Experiment from a dict of persisted experiment fields.

    Legacy metadata may hold ``experiment_id`` as an int; since the field was
    changed from int to string, an int id is converted with ``str()`` before
    the dict is handed to ``Experiment.from_dictionary``. The caller's dict is
    not modified (a shallow copy is normalized instead).

    :param experiment_dict: Mapping of experiment fields; must contain
                            ``experiment_id`` (a missing key raises KeyError).
    :return: The result of ``Experiment.from_dictionary`` on the normalized copy.
    """
    normalized = experiment_dict.copy()
    stored_id = normalized['experiment_id']
    if isinstance(stored_id, int):
        # Legacy experiments stored the id as an int; cast to string on read.
        normalized['experiment_id'] = str(stored_id)
    return Experiment.from_dictionary(normalized)
Example #3
0
def test_message_to_json():
    """
    Serialize an Experiment proto with ``message_to_json`` and check the
    decoded JSON object holds exactly the expected keys and values.
    """
    proto = Experiment("123", "name", "arty", 'active').to_proto()
    expected = {
        "experiment_id": "123",
        "name": "name",
        "artifact_location": "arty",
        "lifecycle_stage": "active",
    }
    assert json.loads(message_to_json(proto)) == expected
Example #4
0
 def list_experiments(self, view_type=ViewType.ACTIVE_ONLY):
     """
     List experiments via the ``ListExperiments`` endpoint.

     :param view_type: A ``ViewType`` value forwarded in the request to select
                       which experiments are returned (default
                       ``ViewType.ACTIVE_ONLY``).
     :return: a list of Experiment objects, one per entry in the response's
              ``experiments`` field, in response order.
     """
     request = ListExperiments(view_type=view_type)
     response_proto = self._call_endpoint(ListExperiments, message_to_json(request))
     return list(map(Experiment.from_proto, response_proto.experiments))
Example #5
0
    def to_mlflow_entity(self):
        """
        Build the MLflow entity that mirrors this DB model.

        ``experiment_id`` is converted with ``str()``; every related tag model
        is converted through its own ``to_mlflow_entity``.

        :return: :py:class:`mlflow.entities.Experiment`.
        """
        # Read fields in the same order the entity constructor receives them.
        entity_id = str(self.experiment_id)
        entity_name = self.name
        entity_artifacts = self.artifact_location
        entity_stage = self.lifecycle_stage
        entity_tags = [tag_model.to_mlflow_entity() for tag_model in self.tags]
        return Experiment(experiment_id=entity_id,
                          name=entity_name,
                          artifact_location=entity_artifacts,
                          lifecycle_stage=entity_stage,
                          tags=entity_tags)
Example #6
0
    def get_experiment(self, experiment_id):
        """
        Look up a single experiment through the ``GetExperiment`` endpoint.

        :param experiment_id: String id for the experiment; the value is passed
                              through ``str()`` before the request is built.

        :return: A single :py:class:`mlflow.entities.Experiment` object if it exists,
                 otherwise raises an Exception.
        """
        request = GetExperiment(experiment_id=str(experiment_id))
        response_proto = self._call_endpoint(GetExperiment, message_to_json(request))
        return Experiment.from_proto(response_proto.experiment)
Example #7
0
 def _create_experiment_with_id(self, name, experiment_id, artifact_uri):
     """
     Create the metadata directory and meta file for a new ACTIVE experiment.

     :param name: Experiment name.
     :param experiment_id: Id of the new experiment; ``str(experiment_id)`` is
                           also used as the metadata directory name.
     :param artifact_uri: Artifact location. When falsy, a default is built with
                          ``append_to_uri_path`` from ``self.artifact_root_uri``
                          and ``str(experiment_id)``.
     :return: ``experiment_id``, unchanged.
     """
     if not artifact_uri:
         artifact_uri = append_to_uri_path(self.artifact_root_uri, str(experiment_id))
     self._check_root_dir()
     experiment_dir = mkdir(self.root_directory, str(experiment_id))
     meta = dict(Experiment(experiment_id, name, artifact_uri, LifecycleStage.ACTIVE))
     # Tags live on the file system separately and are not written to this dict
     # on write, so they must not appear in the meta file.
     del meta['tags']
     write_yaml(experiment_dir, FileStore.META_DATA_FILE_NAME, meta)
     return experiment_id
Example #8
0
 def get_experiment_by_name(self, experiment_name):
     """
     Fetch an experiment by name.

     Calls the ``GetExperimentByName`` endpoint. If the server reports that
     endpoint as missing (``ENDPOINT_NOT_FOUND``, e.g. an older server that only
     implements ListExperiments), falls back to scanning
     ``list_experiments(ViewType.ALL)`` for an experiment with a matching name.

     :param experiment_name: Name of the experiment to look up.
     :return: The matching :py:class:`Experiment`, or ``None`` if the server
              reports ``RESOURCE_DOES_NOT_EXIST`` or the fallback scan finds no
              experiment with that name.
     :raises MlflowException: For any other error code from the endpoint
                              (e.g. request-rate limiting), re-raised unchanged
                              and without falling back.
     """
     req_body = message_to_json(GetExperimentByName(experiment_name=experiment_name))
     # Keep the try body to the remote call only, so errors from building the
     # request or decoding the response are never mistaken for "not found" or
     # "endpoint missing".
     try:
         response_proto = self._call_endpoint(GetExperimentByName, req_body)
     except MlflowException as e:
         if e.error_code == databricks_pb2.ErrorCode.Name(
                 databricks_pb2.RESOURCE_DOES_NOT_EXIST):
             return None
         if e.error_code == databricks_pb2.ErrorCode.Name(
                 databricks_pb2.ENDPOINT_NOT_FOUND):
             # Fall back to using ListExperiments-based implementation.
             return next(
                 (experiment for experiment in self.list_experiments(ViewType.ALL)
                  if experiment.name == experiment_name),
                 None)
         # Bare raise re-raises the active exception with its original traceback.
         raise
     return Experiment.from_proto(response_proto.experiment)
Example #9
0
    def test_get_experiment_by_name(self, store_class):
        """
        Exercise ``get_experiment_by_name`` on a REST store with mocked HTTP.

        Covers four scenarios:
        - a successful lookup;
        - a RESOURCE_DOES_NOT_EXIST response, which must return None;
        - the ListExperiments fallback when GetExperimentByName raises
          ENDPOINT_NOT_FOUND (an old server);
        - REQUEST_LIMIT_EXCEEDED, which must be raised without falling back.
        """
        creds = MlflowHostCreds('https://hello')
        store = store_class(lambda: creds)
        with mock.patch('mlflow.utils.rest_utils.http_request') as mock_http:
            # Use MagicMock *instances*: assigning attributes on the MagicMock
            # class itself would leak status_code/text onto every MagicMock in
            # the test session and make all "responses" one shared object.
            response = mock.MagicMock()
            response.status_code = 200
            experiment = Experiment(
                experiment_id="123", name="abc", artifact_location="/abc",
                lifecycle_stage=LifecycleStage.ACTIVE)
            response.text = json.dumps({
                "experiment": json.loads(message_to_json(experiment.to_proto()))})
            mock_http.return_value = response
            result = store.get_experiment_by_name("abc")
            expected_message0 = GetExperimentByName(experiment_name="abc")
            self._verify_requests(mock_http, creds,
                                  "experiments/get-by-name", "GET",
                                  message_to_json(expected_message0))
            assert result.experiment_id == experiment.experiment_id
            assert result.name == experiment.name
            assert result.artifact_location == experiment.artifact_location
            assert result.lifecycle_stage == experiment.lifecycle_stage
            # Test GetExperimentByName against nonexistent experiment
            mock_http.reset_mock()
            nonexistent_exp_response = mock.MagicMock()
            nonexistent_exp_response.status_code = 404
            nonexistent_exp_response.text =\
                MlflowException("Exp doesn't exist!", RESOURCE_DOES_NOT_EXIST).serialize_as_json()
            mock_http.return_value = nonexistent_exp_response
            assert store.get_experiment_by_name("nonexistent-experiment") is None
            expected_message1 = GetExperimentByName(experiment_name="nonexistent-experiment")
            self._verify_requests(mock_http, creds,
                                  "experiments/get-by-name", "GET",
                                  message_to_json(expected_message1))
            assert mock_http.call_count == 1

            # Test REST client behavior against a mocked old server, which has handler for
            # ListExperiments but not GetExperimentByName
            mock_http.reset_mock()
            list_exp_response = mock.MagicMock()
            list_exp_response.text = json.dumps({
                "experiments": [json.loads(message_to_json(experiment.to_proto()))]})
            list_exp_response.status_code = 200

            def response_fn(*args, **kwargs):
                # pylint: disable=unused-argument
                if kwargs.get('endpoint') == "/api/2.0/mlflow/experiments/get-by-name":
                    raise MlflowException("GetExperimentByName is not implemented",
                                          ENDPOINT_NOT_FOUND)
                else:
                    return list_exp_response

            mock_http.side_effect = response_fn
            result = store.get_experiment_by_name("abc")
            expected_message2 = ListExperiments(view_type=ViewType.ALL)
            self._verify_requests(mock_http, creds,
                                  "experiments/get-by-name", "GET",
                                  message_to_json(expected_message0))
            self._verify_requests(mock_http, creds,
                                  "experiments/list", "GET",
                                  message_to_json(expected_message2))
            assert result.experiment_id == experiment.experiment_id
            assert result.name == experiment.name
            assert result.artifact_location == experiment.artifact_location
            assert result.lifecycle_stage == experiment.lifecycle_stage

            # Verify that REST client won't fall back to ListExperiments for 429 errors (hitting
            # rate limits)
            mock_http.reset_mock()

            def rate_limit_response_fn(*args, **kwargs):
                # pylint: disable=unused-argument
                raise MlflowException("Hit rate limit on GetExperimentByName",
                                      REQUEST_LIMIT_EXCEEDED)

            mock_http.side_effect = rate_limit_response_fn
            with pytest.raises(MlflowException) as exc_info:
                store.get_experiment_by_name("imspamming")
            assert exc_info.value.error_code == ErrorCode.Name(REQUEST_LIMIT_EXCEEDED)
            assert mock_http.call_count == 1