def test_search_model_versions(mock_get_request_message, mock_model_registry_store):
    """Check that the search handler forwards the filter and serializes every result."""
    filter_string = "source_path = 'A/B/CD'"
    mock_get_request_message.return_value = SearchModelVersions(filter=filter_string)

    def _ready_version(model_name, version, created, updated, description, user, stage):
        # Every fixture shares the searched source path and a READY status;
        # each one gets its own random run id.
        return ModelVersionDetailed(RegisteredModel(name=model_name), version=version,
                                    creation_timestamp=created,
                                    last_updated_timestamp=updated,
                                    description=description, user_id=user,
                                    current_stage=stage, source="A/B/CD",
                                    run_id=uuid.uuid4().hex, status="READY",
                                    status_message=None)

    mvds = [
        _ready_version("model_1", 5, 100, 1200, "v 5", "u1", "Production"),
        _ready_version("model_1", 12, 110, 2000, "v 12", "u2", "Production"),
        _ready_version("ads_model", 8, 200, 2000, "v 8", "u1", "Staging"),
        _ready_version("fraud_detection_model", 345, 1000, 1001, "newest version", "u12",
                       "None"),
    ]
    store = mock_model_registry_store
    store.search_model_versions.return_value = mvds

    resp = _search_model_versions()

    # The handler must pass the raw filter string straight through to the store.
    args, _ = store.search_model_versions.call_args
    assert args == (filter_string,)
    assert json.loads(resp.get_data()) == {"model_versions_detailed": jsonify(mvds)}
def test_get_latest_versions(mock_get_request_message, mock_model_registry_store):
    """Check that the latest-versions handler forwards the model and requested stages."""
    rm = RegisteredModel("model1")
    store = mock_model_registry_store

    def _request(**extra_fields):
        # Every request targets the same registered model; only ``stages`` varies.
        return GetLatestVersions(registered_model=rm.to_proto(), **extra_fields)

    mock_get_request_message.return_value = _request()

    fixtures = [
        # (version, created, updated, description, user, stage, source)
        (5, 1, 12, "v 5", "u1", "Production", "A/B"),
        (1, 1, 1200, "v 1", "u1", "Archived", "A/B2"),
        (12, 100, None, "v 12", "u2", "Staging", "A/B3"),
    ]
    mvds = [
        ModelVersionDetailed(registered_model=rm, version=version,
                             creation_timestamp=created, last_updated_timestamp=updated,
                             description=description, user_id=user, current_stage=stage,
                             source=source, run_id=uuid.uuid4().hex, status="READY",
                             status_message=None)
        for version, created, updated, description, user, stage, source in fixtures
    ]
    store.get_latest_versions.return_value = mvds

    resp = _get_latest_versions()

    # With no stages in the request, the store receives an empty stage list.
    args, _ = store.get_latest_versions.call_args
    assert args == (rm, [])
    assert json.loads(resp.get_data()) == {"model_versions_detailed": jsonify(mvds)}

    # Any requested stages must reach the store unchanged.
    for stages in ([], ["None"], ["Staging"], ["Staging", "Production"]):
        mock_get_request_message.return_value = _request(stages=stages)
        _get_latest_versions()
        args, _ = store.get_latest_versions.call_args
        assert args == (rm, stages)
def to_mlflow_detailed_entity(self):
    """
    Convert this record into a ``ModelVersionDetailed`` entity.

    The parent registered model is converted through its own ``to_mlflow_entity``;
    every other field is copied across unchanged.
    """
    return ModelVersionDetailed(
        registered_model=self.registered_model.to_mlflow_entity(),
        version=self.version,
        creation_timestamp=self.creation_time,
        last_updated_timestamp=self.last_updated_time,
        description=self.description,
        user_id=self.user_id,
        current_stage=self.current_stage,
        source=self.source,
        run_id=self.run_id,
        status=self.status,
        status_message=self.status_message,
    )
def _model_version_detailed(name, version, stage, source="some:/source", run_id="run13579"):
    """
    Build a ``ModelVersionDetailed`` test fixture with canned timestamps and metadata.

    Only the model name, version, stage, source and run id vary between calls;
    ``status`` and ``status_message`` are not passed and are left to the constructor.
    """
    return ModelVersionDetailed(
        registered_model=RegisteredModel(name),
        version=version,
        creation_timestamp="2345671890",
        last_updated_timestamp="234567890",
        description="some description",
        user_id="UserID",
        current_stage=stage,
        source=source,
        run_id=run_id,
    )
def get_model_version_details(self, model_version):
    """
    Request the detailed record for ``model_version`` via the ``GetModelVersionDetails``
    endpoint.

    :param model_version: :py:class:`mlflow.entities.model_registry.ModelVersion` object.
    :return: A single :py:class:`mlflow.entities.model_registry.ModelVersionDetailed` object.
    """
    request = GetModelVersionDetails(model_version=model_version.to_proto())
    response_proto = self._call_endpoint(GetModelVersionDetails, message_to_json(request))
    detailed_proto = response_proto.model_version_detailed
    return ModelVersionDetailed.from_proto(detailed_proto)
def test_get_model_version_details(mock_get_request_message, mock_model_registry_store):
    """Check that the details handler forwards the version and serializes the result."""
    rm = RegisteredModel("model1")
    mv = ModelVersion(registered_model=rm, version=32)
    mock_get_request_message.return_value = GetModelVersionDetails(model_version=mv.to_proto())

    mvd = ModelVersionDetailed(
        registered_model=rm,
        version=5,
        creation_timestamp=1,
        last_updated_timestamp=12,
        description="v 5",
        user_id="u1",
        current_stage="Production",
        source="A/B",
        run_id=uuid.uuid4().hex,
        status="READY",
        status_message=None,
    )
    store = mock_model_registry_store
    store.get_model_version_details.return_value = mvd

    resp = _get_model_version_details()

    # The store must be queried with the version decoded from the request.
    call_args, _ = store.get_model_version_details.call_args
    assert call_args == (mv,)
    assert json.loads(resp.get_data()) == {"model_version_detailed": jsonify(mvd)}
def search_model_versions(self, filter_string):
    """
    Search the backend for model versions matching a filter expression.

    :param filter_string: A filter string expression. Currently supports a single filter
                          condition, either on the model name (``name = 'model_name'``) or
                          on the run ID (``run_id = '...'``).
    :return: PagedList of :py:class:`mlflow.entities.model_registry.ModelVersionDetailed`
             objects.
    """
    request_json = message_to_json(SearchModelVersions(filter=filter_string))
    response_proto = self._call_endpoint(SearchModelVersions, request_json)
    results = list(map(ModelVersionDetailed.from_proto,
                       response_proto.model_versions_detailed))
    return PagedList(results, None)
def get_latest_versions(self, registered_model, stages=None):
    """
    Latest version models for each requested stage.

    :param registered_model: :py:class:`mlflow.entities.model_registry.RegisteredModel` object.
    :param stages: List of desired stages. ``None`` and an empty list behave the same here:
                   in both cases the request carries no stage filter, and the server decides
                   which stages to report.
                   NOTE(review): confirm whether the server then reports every stage or only
                   'Staging' and 'Production'.
    :return: List of :py:class:`mlflow.entities.model_registry.ModelVersionDetailed` objects.
    """
    # Passing ``stages=None`` to the proto constructor leaves the repeated field unset,
    # which serializes the same as an empty list.
    req_body = message_to_json(GetLatestVersions(
        registered_model=registered_model.to_proto(), stages=stages))
    response_proto = self._call_endpoint(GetLatestVersions, req_body)
    return [ModelVersionDetailed.from_proto(mvd)
            for mvd in response_proto.model_versions_detailed]
def test_models_artifact_repo_init_with_stage_uri(
        host_creds_mock):  # pylint: disable=unused-argument
    """A stage-based ``models:/`` URI resolves through the latest version's download URI."""
    model_uri = "models:/MyModel/Production"
    artifact_location = "dbfs://databricks/mlflow-registry/12345/models/keras-model"
    latest_version = ModelVersionDetailed(
        registered_model=RegisteredModel("MyModel"),
        version=10,
        creation_timestamp="2345671890",
        last_updated_timestamp="234567890",
        description="some description",
        user_id="UserID",
        current_stage="Production",
        source="source",
        run_id="run12345",
    )
    # Stub both registry lookups so the repository resolves to the DBFS location.
    with mock.patch.object(MlflowClient, "get_latest_versions",
                           return_value=[latest_version]), \
            mock.patch.object(MlflowClient, "get_model_version_download_uri",
                              return_value=artifact_location):
        models_repo = ModelsArtifactRepository(model_uri)
        assert models_repo.artifact_uri == model_uri
        assert isinstance(models_repo.repo, DbfsRestArtifactRepository)
        assert models_repo.repo.artifact_uri == artifact_location