示例#1
0
def test_search_model_versions(mock_get_request_message, mock_model_registry_store):
    """The search handler forwards the filter string to the store and serializes every hit."""
    filter_string = "source_path = 'A/B/CD'"
    mock_get_request_message.return_value = SearchModelVersions(filter=filter_string)

    def _detailed(model_name, version, created, updated, description, user, stage):
        # All fixtures share the same source path, a fresh run id and a READY status.
        return ModelVersionDetailed(RegisteredModel(name=model_name), version=version,
                                    creation_timestamp=created,
                                    last_updated_timestamp=updated,
                                    description=description, user_id=user,
                                    current_stage=stage, source="A/B/CD",
                                    run_id=uuid.uuid4().hex, status="READY",
                                    status_message=None)

    mvds = [
        _detailed("model_1", 5, 100, 1200, "v 5", "u1", "Production"),
        _detailed("model_1", 12, 110, 2000, "v 12", "u2", "Production"),
        _detailed("ads_model", 8, 200, 2000, "v 8", "u1", "Staging"),
        _detailed("fraud_detection_model", 345, 1000, 1001, "newest version", "u12", "None"),
    ]
    store_search = mock_model_registry_store.search_model_versions
    store_search.return_value = mvds

    resp = _search_model_versions()

    positional, _ = store_search.call_args
    assert positional == (filter_string,)
    payload = json.loads(resp.get_data())
    assert payload == {"model_versions_detailed": jsonify(mvds)}
示例#2
0
def test_get_latest_versions(mock_get_request_message,
                             mock_model_registry_store):
    """The handler passes the model and requested stages to the store and serializes results."""
    rm = RegisteredModel("model1")
    store_latest = mock_model_registry_store.get_latest_versions

    def _request(**extra):
        # Every request targets the same registered model; only ``stages`` varies.
        return GetLatestVersions(registered_model=rm.to_proto(), **extra)

    mock_get_request_message.return_value = _request()

    # (version, created, updated, description, user, stage, source)
    fixtures = [
        (5, 1, 12, "v 5", "u1", "Production", "A/B"),
        (1, 1, 1200, "v 1", "u1", "Archived", "A/B2"),
        (12, 100, None, "v 12", "u2", "Staging", "A/B3"),
    ]
    mvds = [
        ModelVersionDetailed(registered_model=rm,
                             version=version,
                             creation_timestamp=created,
                             last_updated_timestamp=updated,
                             description=description,
                             user_id=user,
                             current_stage=stage,
                             source=source,
                             run_id=uuid.uuid4().hex,
                             status="READY",
                             status_message=None)
        for version, created, updated, description, user, stage, source in fixtures
    ]
    store_latest.return_value = mvds

    resp = _get_latest_versions()
    args, _ = store_latest.call_args
    assert args == (rm, [])
    assert json.loads(resp.get_data()) == {
        "model_versions_detailed": jsonify(mvds)
    }

    # Each explicit stage list must reach the store unchanged.
    for stages in ([], ["None"], ["Staging"], ["Staging", "Production"]):
        mock_get_request_message.return_value = _request(stages=stages)
        _get_latest_versions()
        args, _ = store_latest.call_args
        assert args == (rm, stages)
示例#3
0
 def to_mlflow_detailed_entity(self):
     """Build a ``ModelVersionDetailed`` entity from this object's fields.

     The registered model is converted via its own ``to_mlflow_entity``.
     """
     # The stored ``creation_time`` / ``last_updated_time`` fields feed the
     # entity's ``*_timestamp`` arguments.
     return ModelVersionDetailed(
         registered_model=self.registered_model.to_mlflow_entity(),
         version=self.version,
         creation_timestamp=self.creation_time,
         last_updated_timestamp=self.last_updated_time,
         description=self.description,
         user_id=self.user_id,
         current_stage=self.current_stage,
         source=self.source,
         run_id=self.run_id,
         status=self.status,
         status_message=self.status_message,
     )
示例#4
0
def _model_version_detailed(name,
                            version,
                            stage,
                            source="some:/source",
                            run_id="run13579"):
    """Return a ``ModelVersionDetailed`` test fixture.

    Timestamps, description and user id are fixed placeholder values; ``status`` and
    ``status_message`` are left to the constructor's defaults.
    """
    return ModelVersionDetailed(
        registered_model=RegisteredModel(name),
        version=version,
        creation_timestamp="2345671890",
        last_updated_timestamp="234567890",
        description="some description",
        user_id="UserID",
        current_stage=stage,
        source=source,
        run_id=run_id,
    )
    def get_model_version_details(self, model_version):
        """
        Send a ``GetModelVersionDetails`` request via ``_call_endpoint`` and convert the reply.

        :param model_version: :py:class:`mlflow.entities.model_registry.ModelVersion` object.

        :return: A single :py:class:`mlflow.entities.model_registry.ModelVersionDetailed` object.
        """
        request = GetModelVersionDetails(model_version=model_version.to_proto())
        response_proto = self._call_endpoint(GetModelVersionDetails, message_to_json(request))
        detailed_proto = response_proto.model_version_detailed
        return ModelVersionDetailed.from_proto(detailed_proto)
示例#6
0
def test_get_model_version_details(mock_get_request_message, mock_model_registry_store):
    """The handler looks up the requested version in the store and serializes the result."""
    rm = RegisteredModel("model1")
    requested = ModelVersion(registered_model=rm, version=32)
    mock_get_request_message.return_value = GetModelVersionDetails(
        model_version=requested.to_proto())
    detailed = ModelVersionDetailed(
        registered_model=rm,
        version=5,
        creation_timestamp=1,
        last_updated_timestamp=12,
        description="v 5",
        user_id="u1",
        current_stage="Production",
        source="A/B",
        run_id=uuid.uuid4().hex,
        status="READY",
        status_message=None,
    )
    store_details = mock_model_registry_store.get_model_version_details
    store_details.return_value = detailed

    resp = _get_model_version_details()

    call_args, _ = store_details.call_args
    assert call_args == (requested,)
    payload = json.loads(resp.get_data())
    assert payload == {"model_version_detailed": jsonify(detailed)}
    def search_model_versions(self, filter_string):
        """
        Search for model versions in backend that satisfy the filter criteria.

        :param filter_string: A filter string expression. Currently supports a single filter
                              condition: either the model name, e.g. ``name = 'model_name'``,
                              or ``run_id = '...'``.

        :return: PagedList of :py:class:`mlflow.entities.model_registry.ModelVersionDetailed`
                 objects (with no next-page token).
        """
        request = SearchModelVersions(filter=filter_string)
        response_proto = self._call_endpoint(SearchModelVersions, message_to_json(request))
        results = list(map(ModelVersionDetailed.from_proto,
                           response_proto.model_versions_detailed))
        # This method never populates a pagination token.
        return PagedList(results, None)
    def get_latest_versions(self, registered_model, stages=None):
        """
        Latest version models for each requested stage.

        :param registered_model: :py:class:`mlflow.entities.model_registry.RegisteredModel` object.
        :param stages: List of desired stages. ``None`` (the default) or an empty list sends no
                       stages in the request, leaving the choice of stages to the server.
                       NOTE(review): confirm which stages the server reports by default.

        :return: List of :py:class:`mlflow.entities.model_registry.ModelVersionDetailed` objects.
        """
        req_body = message_to_json(GetLatestVersions(
            registered_model=registered_model.to_proto(), stages=stages))
        response_proto = self._call_endpoint(GetLatestVersions, req_body)
        return [ModelVersionDetailed.from_proto(model_version_detailed)
                for model_version_detailed in response_proto.model_versions_detailed]
def test_models_artifact_repo_init_with_stage_uri(
        host_creds_mock):  # pylint: disable=unused-argument
    """A stage-based ``models:/`` URI wraps a DBFS REST repo at the patched download URI."""
    model_uri = "models:/MyModel/Production"
    artifact_location = "dbfs://databricks/mlflow-registry/12345/models/keras-model"
    latest = ModelVersionDetailed(
        registered_model=RegisteredModel("MyModel"),
        version=10,
        creation_timestamp="2345671890",
        last_updated_timestamp="234567890",
        description="some description",
        user_id="UserID",
        current_stage="Production",
        source="source",
        run_id="run12345",
    )
    latest_patch = mock.patch.object(MlflowClient, "get_latest_versions",
                                     return_value=[latest])
    download_patch = mock.patch.object(MlflowClient, "get_model_version_download_uri",
                                       return_value=artifact_location)
    with latest_patch, download_patch:
        repo = ModelsArtifactRepository(model_uri)
        assert repo.artifact_uri == model_uri
        assert isinstance(repo.repo, DbfsRestArtifactRepository)
        assert repo.repo.artifact_uri == artifact_location