def test_update_model_version_stage(self, mock_http):
    """Updating a version's stage issues PATCH model-versions/update with that stage."""
    model_version = ModelVersion(RegisteredModel("model_1"), 5)
    self.store.update_model_version(model_version=model_version, stage="prod")
    expected_request = UpdateModelVersion(model_version=model_version.to_proto(), stage="prod")
    self._verify_requests(mock_http, "model-versions/update", "PATCH", expected_request)
 def test_get_model_version_stages(self, mock_http):
     """Fetching stages issues POST model-versions/get-stages for the given version."""
     model_version = ModelVersion(RegisteredModel("model_11"), 8)
     self.store.get_model_version_stages(model_version=model_version)
     expected_request = GetModelVersionStages(model_version=model_version.to_proto())
     self._verify_requests(mock_http, "model-versions/get-stages", "POST", expected_request)
Esempio n. 3
0
def test_search_model_versions(mock_get_request_message, mock_model_registry_store):
    """The search handler forwards the filter string and serializes every match."""
    filter_string = "source_path = 'A/B/CD'"
    mock_get_request_message.return_value = SearchModelVersions(filter=filter_string)
    # (name, version, created, last_updated, description, user_id, stage)
    specs = [
        ("model_1", "5", 100, 1200, "v 5", "u1", "Production"),
        ("model_1", "12", 110, 2000, "v 12", "u2", "Production"),
        ("ads_model", "8", 200, 2000, "v 8", "u1", "Staging"),
        ("fraud_detection_model", "345", 1000, 1001, "newest version", "u12", "None"),
    ]
    mvds = [
        ModelVersion(name=name, version=version, creation_timestamp=created,
                     last_updated_timestamp=updated, description=description,
                     user_id=user_id, current_stage=stage, source="A/B/CD",
                     run_id=uuid.uuid4().hex, status="READY", status_message=None)
        for name, version, created, updated, description, user_id, stage in specs
    ]
    store = mock_model_registry_store
    store.search_model_versions.return_value = mvds
    resp = _search_model_versions()
    positional, _ = store.search_model_versions.call_args
    assert positional == (filter_string,)
    assert json.loads(resp.get_data()) == {"model_versions": jsonify(mvds)}
Esempio n. 4
0
def test_get_model_name_and_version_with_latest():
    """A 'latest' models URI resolves to the highest version across all stages."""
    latest_versions = [
        ModelVersion(name="mv1", version="10", creation_timestamp=123,
                     current_stage="Production"),
        ModelVersion(name="mv3", version="20", creation_timestamp=125,
                     current_stage="None"),
        ModelVersion(name="mv2", version="15", creation_timestamp=124,
                     current_stage="Staging"),
    ]
    with mock.patch.object(MlflowClient, "get_latest_versions",
                           return_value=latest_versions) as get_latest_versions_mock:
        resolved = get_model_name_and_version(MlflowClient(), "models:/AdsModel1/latest")
        assert resolved == ("AdsModel1", "20")
        # No stage filter is applied when resolving "latest".
        get_latest_versions_mock.assert_called_once_with("AdsModel1", None)
Esempio n. 5
0
def test_delete_model_version(mock_get_request_message, mock_model_registry_store):
    """The delete handler passes the version parsed from the request to the store."""
    model_version = ModelVersion(registered_model=RegisteredModel("model1"), version=32)
    request = DeleteModelVersion(model_version=model_version.to_proto())
    mock_get_request_message.return_value = request
    _delete_model_version()
    positional, _ = mock_model_registry_store.delete_model_version.call_args
    assert positional == (model_version,)
 def test_get_model_version_download_uri(self, mock_http):
     """Resolving a download URI issues POST model-versions/get-download-uri."""
     model_version = ModelVersion(RegisteredModel("model_11"), 8)
     self.store.get_model_version_download_uri(model_version=model_version)
     expected_request = GetModelVersionDownloadUri(model_version=model_version.to_proto())
     self._verify_requests(
         mock_http, "model-versions/get-download-uri", "POST", expected_request)
Esempio n. 7
0
def test_get_latest_versions(mock_get_request_message, mock_model_registry_store):
    """The latest-versions handler forwards name/stages as kwargs and serializes results."""
    name = "model1"
    mock_get_request_message.return_value = GetLatestVersions(name=name)
    # (version, created, last_updated, description, user_id, stage, source)
    specs = [
        ("5", 1, 12, "v 5", "u1", "Production", "A/B"),
        ("1", 1, 1200, "v 1", "u1", "Archived", "A/B2"),
        ("12", 100, None, "v 12", "u2", "Staging", "A/B3"),
    ]
    mvds = [
        ModelVersion(name=name, version=version, creation_timestamp=created,
                     last_updated_timestamp=updated, description=description,
                     user_id=user_id, current_stage=stage, source=source,
                     run_id=uuid.uuid4().hex, status="READY", status_message=None)
        for version, created, updated, description, user_id, stage, source in specs
    ]
    store = mock_model_registry_store
    store.get_latest_versions.return_value = mvds
    resp = _get_latest_versions()
    _, kwargs = store.get_latest_versions.call_args
    assert kwargs == {"name": name, "stages": []}
    assert json.loads(resp.get_data()) == {"model_versions": jsonify(mvds)}

    # Every stage filter, including an empty one, is passed through unchanged.
    for stages in ([], ["None"], ["Staging"], ["Staging", "Production"]):
        mock_get_request_message.return_value = GetLatestVersions(name=name, stages=stages)
        _get_latest_versions()
        _, kwargs = store.get_latest_versions.call_args
        assert kwargs == {"name": name, "stages": stages}
def test_search_model_versions(mock_store):
    """The client forwards the filter to the store and returns all of its results."""
    mock_store.search_model_versions.return_value = [
        ModelVersion(name="Model 1", version=str(number), creation_timestamp=122 + number)
        for number in (1, 2)
    ]
    query = "name=Model 1"
    found = newModelRegistryClient().search_model_versions(query)
    mock_store.search_model_versions.assert_called_once_with(query)
    assert len(found) == 2
Esempio n. 9
0
def test_search_model_versions(mock_store):
    """The client forwards the filter to the store and returns all of its results."""
    mock_store.search_model_versions.return_value = [
        ModelVersion(RegisteredModel("Model 1"), version) for version in (1, 2)
    ]
    query = "name=Model 1"
    found = newModelRegistryClient().search_model_versions(query)
    mock_store.search_model_versions.assert_called_once_with(query)
    assert len(found) == 2
Esempio n. 10
0
def test_update_model_version(mock_get_request_message, mock_model_registry_store):
    """Stage and description from the request reach the store as positional args."""
    model_version = ModelVersion(registered_model=RegisteredModel("model1"), version=32)
    new_stage, new_description = "Production", "Great model!"
    mock_get_request_message.return_value = UpdateModelVersion(
        model_version=model_version.to_proto(), stage=new_stage, description=new_description)
    _update_model_version()
    positional, _ = mock_model_registry_store.update_model_version.call_args
    assert positional == (model_version, new_stage, new_description)
Esempio n. 11
0
def test_get_model_version_download_uri(mock_get_request_message, mock_model_registry_store):
    """The download-URI handler returns whatever URI the store resolves."""
    model_version = ModelVersion(registered_model=RegisteredModel("model1"), version=32)
    download_path = "some/download/path"
    mock_get_request_message.return_value = GetModelVersionDownloadUri(
        model_version=model_version.to_proto())
    store = mock_model_registry_store
    store.get_model_version_download_uri.return_value = download_path
    resp = _get_model_version_download_uri()
    positional, _ = store.get_model_version_download_uri.call_args
    assert positional == (model_version,)
    assert json.loads(resp.get_data()) == {"artifact_uri": download_path}
 def test_update_model_version_decription(self, mock_http):
     """Updating a description issues PATCH model-versions/update with that text."""
     # NOTE: "decription" typo in the name is kept so the test id stays stable.
     model_version = ModelVersion(RegisteredModel("model_1"), 5)
     new_description = "test model version"
     self.store.update_model_version(model_version=model_version, description=new_description)
     expected_request = UpdateModelVersion(model_version=model_version.to_proto(),
                                           description=new_description)
     self._verify_requests(mock_http, "model-versions/update", "PATCH", expected_request)
Esempio n. 13
0
def test_model_version_stages(mock_get_request_message, mock_model_registry_store):
    """The stages handler returns the store's stage list verbatim."""
    model_version = ModelVersion(registered_model=RegisteredModel("model1"), version=32)
    mock_get_request_message.return_value = GetModelVersionStages(
        model_version=model_version.to_proto())
    # Free-form stage names, including numeric-looking and punctuated ones.
    stages = ["Stage1", "Production", "0", "5% traffic", "None"]
    store = mock_model_registry_store
    store.get_model_version_stages.return_value = stages
    resp = _get_model_version_stages()
    positional, _ = store.get_model_version_stages.call_args
    assert positional == (model_version,)
    assert json.loads(resp.get_data()) == {"stages": stages}
Esempio n. 14
0
def test_get_model_version_details(mock_get_request_message, mock_model_registry_store):
    """The details handler serializes the detailed record returned by the store."""
    registered_model = RegisteredModel("model1")
    requested = ModelVersion(registered_model=registered_model, version=32)
    mock_get_request_message.return_value = GetModelVersionDetails(
        model_version=requested.to_proto())
    details = ModelVersionDetailed(
        registered_model=registered_model, version=5, creation_timestamp=1,
        last_updated_timestamp=12, description="v 5", user_id="u1",
        current_stage="Production", source="A/B", run_id=uuid.uuid4().hex,
        status="READY", status_message=None)
    store = mock_model_registry_store
    store.get_model_version_details.return_value = details
    resp = _get_model_version_details()
    positional, _ = store.get_model_version_details.call_args
    assert positional == (requested,)
    assert json.loads(resp.get_data()) == {"model_version_detailed": jsonify(details)}
Esempio n. 15
0
def test_create_model_version(mock_get_request_message, mock_model_registry_store):
    """The create handler forwards request fields as kwargs and serializes the result."""
    run_id = uuid.uuid4().hex
    tags = [
        ModelVersionTag(key="key", value="value"),
        ModelVersionTag(key="anotherKey", value="some other value"),
    ]
    run_link = "localhost:5000/path/to/run"
    mock_get_request_message.return_value = CreateModelVersion(
        name="model_1", source="A/B", run_id=run_id, run_link=run_link,
        tags=[tag.to_proto() for tag in tags])
    created = ModelVersion(name="model_1", version="12", creation_timestamp=123,
                           tags=tags, run_link=run_link)
    store = mock_model_registry_store
    store.create_model_version.return_value = created
    resp = _create_model_version()
    _, kwargs = store.create_model_version.call_args
    for field, expected in (("name", "model_1"), ("source", "A/B"), ("run_id", run_id)):
        assert kwargs[field] == expected
    # Tags are compared as key/value mappings, independent of their object type.
    received_tags = {tag.key: tag.value for tag in kwargs["tags"]}
    assert received_tags == {tag.key: tag.value for tag in tags}
    assert kwargs["run_link"] == run_link
    assert json.loads(resp.get_data()) == {"model_version": jsonify(created)}
def test_create_model_version_no_run_id(mock_store):
    """Omitting run_id makes the client pass None for it to the store."""
    name, version = "Model 1", "1"
    source = "uri:/for/source"
    description = "best model ever"
    tags_dict = {"key": "value", "another key": "some other value"}
    tags = [ModelVersionTag(key, value) for key, value in tags_dict.items()]

    mock_store.create_model_version.return_value = ModelVersion(
        name=name, version=version, creation_timestamp=123, tags=tags,
        run_link=None, description=description)
    result = newModelRegistryClient().create_model_version(
        name, source, tags=tags_dict, run_link=None, description=description)
    # Positional order: name, source, run_id, tags, run_link, description.
    mock_store.create_model_version.assert_called_once_with(
        name, source, None, tags, None, description)

    assert result.name == name
    assert result.version == version
    assert result.tags == tags_dict
    assert result.run_id is None
Esempio n. 17
0
def test_models_artifact_repo_init_with_stage_uri(
    host_creds_mock,
):  # pylint: disable=unused-argument
    """A stage-based models URI resolves through the registry to a DBFS repository."""
    model_uri = "models:/MyModel/Production"
    artifact_location = "dbfs:/databricks/mlflow-registry/12345/models/keras-model"
    latest = ModelVersion("MyModel", "10", "2345671890", "234567890", "some description",
                          "UserID", "Production", "source", "run12345")
    with mock.patch.object(MlflowClient, "get_latest_versions", return_value=[latest]), \
            mock.patch.object(MlflowClient, "get_model_version_download_uri",
                              return_value=artifact_location):
        models_repo = ModelsArtifactRepository(model_uri)
        assert models_repo.artifact_uri == model_uri
        assert isinstance(models_repo.repo, DbfsRestArtifactRepository)
        assert models_repo.repo.artifact_uri == artifact_location
Esempio n. 18
0
def test_create_model_version_run_link_with_configured_profile(mock_registry_store):
    """Outside a notebook, the run link uses workspace info from Databricks secrets."""
    experiment_id = "test-exp-id"
    hostname = "https://workspace.databricks.com/"
    workspace_id = "10002"
    run_id = "runid"
    workspace_url = construct_run_url(hostname, experiment_id, run_id, workspace_id)
    run = Run(RunInfo(run_id, experiment_id, "userid", "status", 0, 1, None), None)
    get_run_mock = mock.MagicMock(return_value=run)
    not_in_notebook = mock.patch("mlflow.tracking.client.is_in_databricks_notebook",
                                 return_value=False)
    profile_workspace_info = mock.patch(
        "mlflow.tracking.client.get_workspace_info_from_databricks_secrets",
        return_value=(hostname, workspace_id))
    with not_in_notebook, profile_workspace_info:
        client = MlflowClient(tracking_uri="databricks", registry_uri="otherplace")
        client.get_run = get_run_mock
        mock_registry_store.create_model_version.return_value = ModelVersion(
            "name", 1, 0, 1, source="source", run_id=run_id, run_link=workspace_url)
        model_version = client.create_model_version("name", "source", "runid")
        assert model_version.run_link == workspace_url
        # The client itself must build the workspace URL and hand it to the store.
        mock_registry_store.create_model_version.assert_called_once_with(
            "name", "source", "runid", [], workspace_url, None)
Esempio n. 19
0
def test_models_artifact_repo_init_with_stage_uri_and_not_using_databricks_registry():
    """A non-DBFS download URI is handed to the generic artifact-repository factory."""
    model_uri = "models:/MyModel/Staging"
    artifact_location = "s3://blah_bucket/"
    latest = ModelVersion("MyModel", "10", "2345671890", "234567890", "some description",
                          "UserID", "Production", "source", "run12345")
    repo_factory = "mlflow.store.artifact.artifact_repository_registry.get_artifact_repository"
    with mock.patch.object(MlflowClient, "get_latest_versions", return_value=[latest]), \
            mock.patch.object(MlflowClient, "get_model_version_download_uri",
                              return_value=artifact_location), \
            mock.patch(repo_factory) as get_repo_mock:
        get_repo_mock.return_value = None
        ModelsArtifactRepository(model_uri)
        get_repo_mock.assert_called_once_with(artifact_location)
Esempio n. 20
0
def test_models_artifact_repo_init_with_stage_uri_and_db_profile():
    """A profile in the models URI is carried into the resolved DBFS URI."""
    model_uri = "models://profile@databricks/MyModel/Staging"
    artifact_location = "dbfs:/databricks/mlflow-registry/12345/models/keras-model"
    final_uri = "dbfs://profile@databricks/databricks/mlflow-registry/12345/models/keras-model"
    latest = ModelVersion("MyModel", "10", "2345671890", "234567890", "some description",
                          "UserID", "Production", "source", "run12345")
    dbfs_repo_path = "mlflow.store.artifact.dbfs_artifact_repo.DbfsRestArtifactRepository"
    with mock.patch.object(MlflowClient, "get_latest_versions", return_value=[latest]), \
            mock.patch.object(MlflowClient, "get_model_version_download_uri",
                              return_value=artifact_location), \
            mock.patch(dbfs_repo_path, autospec=True) as mock_repo:
        models_repo = ModelsArtifactRepository(model_uri)
        assert models_repo.artifact_uri == model_uri
        assert isinstance(models_repo.repo, DbfsRestArtifactRepository)
        mock_repo.assert_called_once_with(final_uri)
Esempio n. 21
0
    def create_model_version(self, name, source, run_id, tags=None, run_link=None,
                             description=None):
        """
        Register a new model version built from a source path and a run ID.

        :param name: Registered model name.
        :param source: Source path where the MLflow model is stored.
        :param run_id: Run ID from MLflow tracking server that generated the model.
        :param tags: A list of :py:class:`mlflow.entities.model_registry.ModelVersionTag`
                     instances associated with this model version. ``None`` means no tags.
        :param run_link: Link to the run from an MLflow tracking server that generated this model.
        :param description: Description of the version.
        :return: A single object of :py:class:`mlflow.entities.model_registry.ModelVersion`
                 created in the backend.
        """
        request = CreateModelVersion(
            name=name,
            source=source,
            run_id=run_id,
            run_link=run_link,
            tags=[tag.to_proto() for tag in (tags or [])],
            description=description,
        )
        response_proto = self._call_endpoint(CreateModelVersion, message_to_json(request))
        return ModelVersion.from_proto(response_proto.model_version)
Esempio n. 22
0
def test_create_model_version_run_link_in_notebook_with_default_profile(mock_registry_store):
    """Inside a notebook, the run link uses workspace info obtained from dbutils."""
    experiment_id = "test-exp-id"
    hostname = "https://workspace.databricks.com/"
    workspace_id = "10002"
    run_id = "runid"
    workspace_url = construct_run_url(hostname, experiment_id, run_id, workspace_id)
    run = Run(RunInfo(run_id, experiment_id, "userid", "status", 0, 1, None), None)
    get_run_mock = mock.MagicMock(return_value=run)
    in_notebook = mock.patch("mlflow.tracking.client.is_in_databricks_notebook",
                             return_value=True)
    notebook_workspace_info = mock.patch(
        "mlflow.tracking.client.get_workspace_info_from_dbutils",
        return_value=(hostname, workspace_id))
    with in_notebook, notebook_workspace_info:
        client = MlflowClient(tracking_uri="databricks", registry_uri="otherplace")
        client.get_run = get_run_mock
        mock_registry_store.create_model_version.return_value = ModelVersion(
            "name", 1, 0, 1, source="source", run_id=run_id, run_link=workspace_url)
        model_version = client.create_model_version("name", "source", "runid")
        assert model_version.run_link == workspace_url
        # The client itself must build the workspace URL and hand it to the store.
        # NOTE(review): expects exactly five positional args; a client that also
        # forwards `description` would pass a sixth - confirm the targeted version.
        mock_registry_store.create_model_version.assert_called_once_with(
            "name", "source", "runid", [], workspace_url)
Esempio n. 23
0
def _get_model_version_stages():
    """Handle a GetModelVersionStages request by listing the version's available stages."""
    request_message = _get_request_message(GetModelVersionStages())
    store = _get_model_registry_store()
    stages = store.get_model_version_stages(
        ModelVersion.from_proto(request_message.model_version))
    response_message = GetModelVersionStages.Response()
    response_message.stages.extend(stages)
    return _wrap_response(response_message)
Esempio n. 24
0
def _get_model_version_download_uri():
    """Handle a GetModelVersionDownloadUri request by returning the store's artifact URI."""
    request_message = _get_request_message(GetModelVersionDownloadUri())
    store = _get_model_registry_store()
    model_version = ModelVersion.from_proto(request_message.model_version)
    artifact_uri = store.get_model_version_download_uri(model_version)
    return _wrap_response(GetModelVersionDownloadUri.Response(artifact_uri=artifact_uri))
Esempio n. 25
0
def _model_version(name, version, stage, source="some:/source", run_id="run13579"):
    """Build a fixture ModelVersion; timestamps, description and user are fixed."""
    # Positional slots 3-6: creation ts, last-updated ts, description, user id.
    fixed_fields = ("2345671890", "234567890", "some description", "UserID")
    return ModelVersion(name, version, *fixed_fields, stage, source, run_id)
Esempio n. 26
0
def _get_model_version_details():
    """Handle a GetModelVersionDetails request by returning the detailed version record."""
    request_message = _get_request_message(GetModelVersionDetails())
    store = _get_model_registry_store()
    details = store.get_model_version_details(
        ModelVersion.from_proto(request_message.model_version))
    return _wrap_response(
        GetModelVersionDetails.Response(model_version_detailed=details.to_proto()))
 def test_init_with_stage_uri_and_profile_is_inferred(
         self, stage_uri_without_profile):
     """Without a profile in the URI, the registry URI supplies the Databricks profile."""
     latest = ModelVersion(MOCK_MODEL_NAME, MOCK_MODEL_VERSION, "2345671890", "234567890",
                           "some description", "UserID", "Production", "source", "run12345")
     with mock.patch.object(MlflowClient, "get_latest_versions", return_value=[latest]), \
             mock.patch("mlflow.store.artifact.utils.models.mlflow.get_registry_uri",
                        return_value=MOCK_PROFILE), \
             mock.patch("mlflow.tracking.get_registry_uri", return_value=MOCK_PROFILE):
         repo = DatabricksModelsArtifactRepository(stage_uri_without_profile)
         expectations = (
             ("artifact_uri", stage_uri_without_profile),
             ("model_name", MOCK_MODEL_NAME),
             ("model_version", MOCK_MODEL_VERSION),
             ("databricks_profile_uri", MOCK_PROFILE),
         )
         for attribute, expected in expectations:
             assert getattr(repo, attribute) == expected
Esempio n. 28
0
 def get_model_version_details(self, name, version):
     """
     Look up the detailed record for one version of a registered model.

     :param name: Name of the containing registered model.
     :param version: Version number of the model version.
     :return: A single :py:class:`mlflow.entities.model_registry.ModelVersionDetailed` object.
     """
     store = self.store
     return store.get_model_version_details(ModelVersion(RegisteredModel(name), version))
Esempio n. 29
0
    def delete_model_version(self, name, version):
        """
        Remove one version of a registered model from the backend store.

        :param name: Name of the containing registered model.
        :param version: Version number of the model version.
        """
        store = self.store
        store.delete_model_version(ModelVersion(RegisteredModel(name), version))
Esempio n. 30
0
def test_create_model_version(mock_store):
    """The client forwards name/source/run_id to the store and returns its result."""
    name, version = "Model 1", "1"
    source, run_id = "uri:/for/source", "run123"
    mock_store.create_model_version.return_value = ModelVersion(
        name=name, version=version, creation_timestamp=123)
    result = newModelRegistryClient().create_model_version(name, source, run_id)
    mock_store.create_model_version.assert_called_once_with(name, source, run_id)
    assert result.name == name
    assert result.version == version