def test_parse_legacy_experiment():
    """
    A legacy JSON payload (integer experiment id plus an extra key) parsed into
    the proto and converted to an ``Experiment`` is expected to expose a string
    id, the given name, and an empty artifact location.
    """
    proto = ProtoExperiment()
    parse_dict({"experiment_id": 123, "name": "name", "unknown": "field"}, proto)
    parsed = Experiment.from_proto(proto)
    assert parsed.experiment_id == "123"
    assert parsed.name == 'name'
    assert parsed.artifact_location == ''
def _read_persisted_experiment_dict(experiment_dict):
    """
    Build an ``Experiment`` from a dict read back from persisted metadata.

    The input dict is not mutated; a shallow copy is normalized instead.

    :param experiment_dict: Mapping of experiment fields; must contain
                            ``'experiment_id'``.
    :return: ``Experiment.from_dictionary`` applied to the normalized copy.
    """
    normalized = experiment_dict.copy()
    stored_id = normalized['experiment_id']
    # Older metadata stored the id as an int; the entity now expects a string,
    # so legacy integer ids are converted on read.
    if isinstance(stored_id, int):
        normalized['experiment_id'] = str(stored_id)
    return Experiment.from_dictionary(normalized)
def test_message_to_json():
    """An Experiment proto serialized via ``message_to_json`` round-trips to the expected JSON object."""
    proto = Experiment("123", "name", "arty", 'active').to_proto()
    expected = {
        "experiment_id": "123",
        "name": "name",
        "artifact_location": "arty",
        "lifecycle_stage": 'active',
    }
    serialized = message_to_json(proto)
    assert json.loads(serialized) == expected
def list_experiments(self, view_type=ViewType.ACTIVE_ONLY):
    """
    Request the experiment list from the REST endpoint.

    :param view_type: Passed through as the ``view_type`` field of the
                      ``ListExperiments`` request.
    :return: a list of all known Experiment objects
    """
    request = ListExperiments(view_type=view_type)
    response_proto = self._call_endpoint(ListExperiments, message_to_json(request))
    return list(map(Experiment.from_proto, response_proto.experiments))
def to_mlflow_entity(self):
    """
    Convert this DB model row into an MLflow ``Experiment`` entity.

    The numeric id is stringified, and each tag row is converted via its own
    ``to_mlflow_entity``.

    :return: :py:class:`mlflow.entities.Experiment`.
    """
    entity_tags = [tag_row.to_mlflow_entity() for tag_row in self.tags]
    return Experiment(
        experiment_id=str(self.experiment_id),
        name=self.name,
        artifact_location=self.artifact_location,
        lifecycle_stage=self.lifecycle_stage,
        tags=entity_tags,
    )
def get_experiment(self, experiment_id):
    """
    Fetch a single experiment from the backend store.

    :param experiment_id: Experiment id; converted with ``str()`` before being
                          placed in the request.
    :return: A single :py:class:`mlflow.entities.Experiment` object if it exists,
             otherwise the endpoint call raises an Exception.
    """
    request = GetExperiment(experiment_id=str(experiment_id))
    response_proto = self._call_endpoint(GetExperiment, message_to_json(request))
    return Experiment.from_proto(response_proto.experiment)
def _create_experiment_with_id(self, name, experiment_id, artifact_uri):
    """
    Create the metadata directory and meta file for an experiment with a given id.

    :param name: Experiment name.
    :param experiment_id: Id to use; its ``str()`` form names the directory.
    :param artifact_uri: Artifact location; when falsy, defaults to
                         ``<artifact_root_uri>/<experiment_id>``.
    :return: ``experiment_id``, unchanged.
    """
    if not artifact_uri:
        artifact_uri = append_to_uri_path(self.artifact_root_uri, str(experiment_id))
    self._check_root_dir()
    meta_dir = mkdir(self.root_directory, str(experiment_id))
    new_experiment = Experiment(experiment_id, name, artifact_uri, LifecycleStage.ACTIVE)
    meta_contents = dict(new_experiment)
    # Tags live in their own files on disk rather than in the meta file,
    # so they are dropped before the meta file is written.
    del meta_contents['tags']
    write_yaml(meta_dir, FileStore.META_DATA_FILE_NAME, meta_contents)
    return experiment_id
def get_experiment_by_name(self, experiment_name):
    """
    Fetch an experiment by name via the ``GetExperimentByName`` endpoint.

    If the server reports ``ENDPOINT_NOT_FOUND``, the lookup falls back to
    scanning ``list_experiments(ViewType.ALL)`` for a matching name.

    :param experiment_name: Name of the experiment to look up.
    :return: The matching :py:class:`mlflow.entities.Experiment`, or ``None`` if
             the server reports ``RESOURCE_DOES_NOT_EXIST`` or, in the
             fallback path, no listed experiment has that name.
    :raises MlflowException: for any other error code (e.g. rate limiting),
                             re-raised with its original traceback.
    """
    try:
        req_body = message_to_json(
            GetExperimentByName(experiment_name=experiment_name))
        response_proto = self._call_endpoint(GetExperimentByName, req_body)
        return Experiment.from_proto(response_proto.experiment)
    except MlflowException as e:
        if e.error_code == databricks_pb2.ErrorCode.Name(
                databricks_pb2.RESOURCE_DOES_NOT_EXIST):
            return None
        if e.error_code == databricks_pb2.ErrorCode.Name(
                databricks_pb2.ENDPOINT_NOT_FOUND):
            # Fall back to using ListExperiments-based implementation.
            return next(
                (experiment for experiment in self.list_experiments(ViewType.ALL)
                 if experiment.name == experiment_name),
                None)
        # Bare ``raise`` preserves the original traceback (``raise e`` would
        # add this frame to it).
        raise
def test_get_experiment_by_name(self, store_class):
    """
    Exercise the store's ``get_experiment_by_name`` against mocked HTTP responses:
    a successful lookup, a RESOURCE_DOES_NOT_EXIST response, an old server that
    lacks the get-by-name endpoint (ListExperiments fallback), and a rate-limit
    error, which must not trigger the fallback.
    """
    creds = MlflowHostCreds('https://hello')
    store = store_class(lambda: creds)
    with mock.patch('mlflow.utils.rest_utils.http_request') as mock_http:
        # Use MagicMock *instances*. Assigning attributes on the MagicMock class
        # itself would mutate it globally and leak into every other mock.
        response = mock.MagicMock()
        response.status_code = 200
        experiment = Experiment(
            experiment_id="123", name="abc", artifact_location="/abc",
            lifecycle_stage=LifecycleStage.ACTIVE)
        response.text = json.dumps({
            "experiment": json.loads(message_to_json(experiment.to_proto()))})
        mock_http.return_value = response
        result = store.get_experiment_by_name("abc")
        expected_message0 = GetExperimentByName(experiment_name="abc")
        self._verify_requests(mock_http, creds,
                              "experiments/get-by-name", "GET",
                              message_to_json(expected_message0))
        assert result.experiment_id == experiment.experiment_id
        assert result.name == experiment.name
        assert result.artifact_location == experiment.artifact_location
        assert result.lifecycle_stage == experiment.lifecycle_stage

        # Test GetExperimentByName against nonexistent experiment
        mock_http.reset_mock()
        nonexistent_exp_response = mock.MagicMock()
        nonexistent_exp_response.status_code = 404
        nonexistent_exp_response.text = \
            MlflowException("Exp doesn't exist!", RESOURCE_DOES_NOT_EXIST).serialize_as_json()
        mock_http.return_value = nonexistent_exp_response
        assert store.get_experiment_by_name("nonexistent-experiment") is None
        expected_message1 = GetExperimentByName(experiment_name="nonexistent-experiment")
        self._verify_requests(mock_http, creds,
                              "experiments/get-by-name", "GET",
                              message_to_json(expected_message1))
        assert mock_http.call_count == 1

        # Test REST client behavior against a mocked old server, which has handler for
        # ListExperiments but not GetExperimentByName
        mock_http.reset_mock()
        list_exp_response = mock.MagicMock()
        list_exp_response.text = json.dumps({
            "experiments": [json.loads(message_to_json(experiment.to_proto()))]})
        list_exp_response.status_code = 200

        def response_fn(*args, **kwargs):  # pylint: disable=unused-argument
            if kwargs.get('endpoint') == "/api/2.0/mlflow/experiments/get-by-name":
                raise MlflowException("GetExperimentByName is not implemented",
                                      ENDPOINT_NOT_FOUND)
            else:
                return list_exp_response

        mock_http.side_effect = response_fn
        result = store.get_experiment_by_name("abc")
        expected_message2 = ListExperiments(view_type=ViewType.ALL)
        self._verify_requests(mock_http, creds,
                              "experiments/get-by-name", "GET",
                              message_to_json(expected_message0))
        self._verify_requests(mock_http, creds,
                              "experiments/list", "GET",
                              message_to_json(expected_message2))
        assert result.experiment_id == experiment.experiment_id
        assert result.name == experiment.name
        assert result.artifact_location == experiment.artifact_location
        assert result.lifecycle_stage == experiment.lifecycle_stage

        # Verify that REST client won't fall back to ListExperiments for 429 errors (hitting
        # rate limits)
        mock_http.reset_mock()

        def rate_limit_response_fn(*args, **kwargs):  # pylint: disable=unused-argument
            raise MlflowException("Hit rate limit on GetExperimentByName",
                                  REQUEST_LIMIT_EXCEEDED)

        mock_http.side_effect = rate_limit_response_fn
        with pytest.raises(MlflowException) as exc_info:
            store.get_experiment_by_name("imspamming")
        assert exc_info.value.error_code == ErrorCode.Name(REQUEST_LIMIT_EXCEEDED)
        assert mock_http.call_count == 1