Beispiel #1
0
def test_search_model_versions(mock_store):
    """The client passes the filter string straight to the store and returns its results."""
    filter_string = "name=Model 1"
    stored_versions = [
        ModelVersion(name="Model 1", version=version_str, creation_timestamp=created)
        for version_str, created in (("1", 123), ("2", 124))
    ]
    mock_store.search_model_versions.return_value = stored_versions

    found = newModelRegistryClient().search_model_versions(filter_string)

    mock_store.search_model_versions.assert_called_once_with(filter_string)
    assert len(found) == 2
Beispiel #2
0
def test_search_model_versions(mock_get_request_message,
                               mock_model_registry_store):
    """The search handler forwards the filter and serializes every returned version."""
    mock_get_request_message.return_value = SearchModelVersions(
        filter="source_path = 'A/B/CD'")
    # Per-version fields: name, version, created, last updated, description, user, stage.
    version_specs = [
        ("model_1", "5", 100, 1200, "v 5", "u1", "Production"),
        ("model_1", "12", 110, 2000, "v 12", "u2", "Production"),
        ("ads_model", "8", 200, 2000, "v 8", "u1", "Staging"),
        ("fraud_detection_model", "345", 1000, 1001, "newest version", "u12",
         "None"),
    ]
    mvds = [
        ModelVersion(name=model_name,
                     version=version_str,
                     creation_timestamp=created,
                     last_updated_timestamp=updated,
                     description=desc,
                     user_id=user,
                     current_stage=stage,
                     source="A/B/CD",
                     run_id=uuid.uuid4().hex,
                     status="READY",
                     status_message=None)
        for model_name, version_str, created, updated, desc, user, stage in version_specs
    ]
    mock_model_registry_store.search_model_versions.return_value = mvds

    resp = _search_model_versions()

    positional, _ = mock_model_registry_store.search_model_versions.call_args
    assert positional == ("source_path = 'A/B/CD'", )
    assert json.loads(resp.get_data()) == {"model_versions": jsonify(mvds)}
Beispiel #3
0
def test_get_latest_versions(mock_get_request_message,
                             mock_model_registry_store):
    """The latest-versions handler forwards name and stages and serializes the results."""
    name = "model1"
    store_call = mock_model_registry_store.get_latest_versions
    mock_get_request_message.return_value = GetLatestVersions(name=name)
    # Per-version fields: version, created, last updated, description, user, stage, source.
    version_specs = [
        ("5", 1, 12, "v 5", "u1", "Production", "A/B"),
        ("1", 1, 1200, "v 1", "u1", "Archived", "A/B2"),
        ("12", 100, None, "v 12", "u2", "Staging", "A/B3"),
    ]
    mvds = [
        ModelVersion(name=name,
                     version=version_str,
                     creation_timestamp=created,
                     last_updated_timestamp=updated,
                     description=desc,
                     user_id=user,
                     current_stage=stage,
                     source=src,
                     run_id=uuid.uuid4().hex,
                     status="READY",
                     status_message=None)
        for version_str, created, updated, desc, user, stage, src in version_specs
    ]
    store_call.return_value = mvds

    resp = _get_latest_versions()
    assert store_call.call_args[1] == {"name": name, "stages": []}
    assert json.loads(resp.get_data()) == {"model_versions": jsonify(mvds)}

    # Each explicit stage list must reach the store unchanged.
    for requested in ([], ["None"], ["Staging"], ["Staging", "Production"]):
        mock_get_request_message.return_value = GetLatestVersions(
            name=name, stages=requested)
        _get_latest_versions()
        assert store_call.call_args[1] == {"name": name, "stages": requested}
Beispiel #4
0
 def to_mlflow_entity(self):
     """
     Build a :py:class:`ModelVersion` entity from this object's fields.

     Each entry of ``self.model_version_tags`` is converted with its own
     ``to_mlflow_entity`` and passed as the entity's ``tags``.
     """
     return ModelVersion(
         name=self.name,
         version=self.version,
         creation_timestamp=self.creation_time,
         last_updated_timestamp=self.last_updated_time,
         description=self.description,
         user_id=self.user_id,
         current_stage=self.current_stage,
         source=self.source,
         run_id=self.run_id,
         status=self.status,
         status_message=self.status_message,
         tags=[tag_row.to_mlflow_entity() for tag_row in self.model_version_tags])
Beispiel #5
0
def test_create_model_version(mock_get_request_message,
                              mock_model_registry_store):
    """The create handler forwards name, source, run ID and tags, then serializes the result."""
    run_id = uuid.uuid4().hex
    tags = [
        ModelVersionTag(key="key", value="value"),
        ModelVersionTag(key="anotherKey", value="some other value")
    ]
    mock_get_request_message.return_value = CreateModelVersion(
        name="model_1",
        source="A/B",
        run_id=run_id,
        tags=[t.to_proto() for t in tags])
    mv = ModelVersion(name="model_1", version="12", creation_timestamp=123, tags=tags)
    mock_model_registry_store.create_model_version.return_value = mv

    resp = _create_model_version()

    forwarded = mock_model_registry_store.create_model_version.call_args[1]
    for key, expected in (("name", "model_1"), ("source", "A/B"), ("run_id", run_id)):
        assert forwarded[key] == expected
    forwarded_tags = {t.key: t.value for t in forwarded["tags"]}
    assert forwarded_tags == {t.key: t.value for t in tags}
    assert json.loads(resp.get_data()) == {"model_version": jsonify(mv)}
Beispiel #6
0
    def get_model_version(self, name, version):
        """
        Fetch one model version from the backend.

        :param name: Name of the registered model.
        :param version: Version of the registered model; sent as a string.
        :return: A single :py:class:`mlflow.entities.model_registry.ModelVersion` object.
        """
        request = GetModelVersion(name=name, version=str(version))
        response_proto = self._call_endpoint(GetModelVersion, message_to_json(request))
        return ModelVersion.from_proto(response_proto.model_version)
Beispiel #7
0
def test_update_model_version(mock_store):
    """The client forwards a new description to the store and returns the store's result."""
    name = "Model 1"
    version = "12"
    description = "new description"
    expected_result = ModelVersion(name,
                                   version,
                                   creation_timestamp=123,
                                   description=description)
    mock_store.update_model_version.return_value = expected_result
    # Reuse ``description`` so the call and the assertion cannot drift apart.
    actual_result = newModelRegistryClient().update_model_version(
        name, version, description)
    mock_store.update_model_version.assert_called_once_with(
        name=name, version=version, description=description)
    assert expected_result == actual_result
Beispiel #8
0
def test_create_model_version(mock_store):
    """The client turns a tag dict into ModelVersionTag objects before calling the store."""
    name = "Model 1"
    version = "1"
    source = "uri:/for/source"
    run_id = "run123"
    tags_dict = {"key": "value", "another key": "some other value"}
    tags = [ModelVersionTag(tag_key, tag_value) for tag_key, tag_value in tags_dict.items()]
    mock_store.create_model_version.return_value = ModelVersion(
        name=name, version=version, creation_timestamp=123, tags=tags)

    result = newModelRegistryClient().create_model_version(name, source, run_id, tags_dict)

    mock_store.create_model_version.assert_called_once_with(name, source, run_id, tags)
    assert result.name == name
    assert result.version == version
    assert result.tags == tags_dict
Beispiel #9
0
    def update_model_version(self, name, version, description):
        """
        Change the description of a model version in the backend.

        :param name: Name of the registered model.
        :param version: Version of the registered model; sent as a string.
        :param description: New description for this model version.
        :return: A single :py:class:`mlflow.entities.model_registry.ModelVersion` object,
                 built from the backend's response.
        """
        request = UpdateModelVersion(name=name,
                                     version=str(version),
                                     description=description)
        response_proto = self._call_endpoint(UpdateModelVersion, message_to_json(request))
        return ModelVersion.from_proto(response_proto.model_version)
Beispiel #10
0
def _model_version(name,
                   version,
                   stage,
                   source="some:/source",
                   run_id="run13579",
                   tags=None):
    """Build a ``ModelVersion`` test fixture with fixed timestamps, description and user."""
    return ModelVersion(name=name,
                        version=version,
                        creation_timestamp="2345671890",
                        last_updated_timestamp="234567890",
                        description="some description",
                        user_id="UserID",
                        current_stage=stage,
                        source=source,
                        run_id=run_id,
                        tags=tags)
Beispiel #11
0
def test_register_model_with_non_runs_uri():
    """A non-``runs:/`` URI is registered as the version source with no run ID."""
    model_name = "Model 1"
    model_uri = "s3:/some/path/to/model"
    create_model_patch = mock.patch.object(
        MlflowClient, "create_registered_model",
        return_value=RegisteredModel(model_name))
    create_version_patch = mock.patch.object(
        MlflowClient, "create_model_version",
        return_value=ModelVersion(model_name, "1", creation_timestamp=123))
    with create_model_patch as create_model_mock, create_version_patch as create_version_mock:
        register_model(model_uri, model_name)
        create_model_mock.assert_called_once_with(model_name)
        create_version_mock.assert_called_once_with(
            model_name, run_id=None, source=model_uri)
Beispiel #12
0
    def get_latest_versions(self, name, stages=None):
        """
        Latest version models for each requested stage. If no ``stages`` argument is provided,
        returns the latest version for each stage.

        :param name: Registered model name.
        :param stages: List of desired stages. If ``None`` (the default), the request is built
                       without any stages and the backend chooses which stages to report on.
        :return: List of :py:class:`mlflow.entities.model_registry.ModelVersion` objects.
        """
        req_body = message_to_json(GetLatestVersions(name=name, stages=stages))
        response_proto = self._call_endpoint(GetLatestVersions, req_body)
        return [
            ModelVersion.from_proto(model_version)
            for model_version in response_proto.model_versions
        ]
Beispiel #13
0
    def search_model_versions(self, filter_string):
        """
        Query the backend for model versions that match a filter.

        :param filter_string: Filter expression. Only one condition is currently supported:
                              either the model name (``name = 'model_name'``) or the run ID
                              (``run_id = '...'``).
        :return: PagedList of :py:class:`mlflow.entities.model_registry.ModelVersion`
                 objects, carrying the response's ``next_page_token``.
        """
        request_json = message_to_json(SearchModelVersions(filter=filter_string))
        response_proto = self._call_endpoint(SearchModelVersions, request_json)
        matches = list(map(ModelVersion.from_proto, response_proto.model_versions))
        return PagedList(matches, response_proto.next_page_token)
Beispiel #14
0
def test_register_model_with_existing_registered_model():
    """If creating the registered model raises RESOURCE_ALREADY_EXISTS, a version is still created."""
    model_name = "Model 1"
    model_uri = "s3:/some/path/to/model"
    already_exists = MlflowException("Some Message", RESOURCE_ALREADY_EXISTS)
    create_model_patch = mock.patch.object(
        MlflowClient, "create_registered_model", side_effect=already_exists)
    create_version_patch = mock.patch.object(
        MlflowClient, "create_model_version",
        return_value=ModelVersion(model_name, "1", creation_timestamp=123))
    with create_model_patch as create_model_mock, create_version_patch as create_version_mock:
        register_model(model_uri, model_name)
        create_model_mock.assert_called_once_with(model_name)
        create_version_mock.assert_called_once_with(
            model_name, run_id=None, source=model_uri)
Beispiel #15
0
def test_transition_model_version_stage(mock_store):
    """The client forwards a stage transition to the store, with archiving disabled by default."""
    name, version, stage = "Model 1", "12", "Production"
    expected_result = ModelVersion(name, version, creation_timestamp=123, current_stage=stage)
    mock_store.transition_model_version_stage.return_value = expected_result

    actual_result = newModelRegistryClient().transition_model_version_stage(
        name, version, stage)

    expected_kwargs = dict(name=name, version=version, stage=stage,
                           archive_existing_versions=False)
    mock_store.transition_model_version_stage.assert_called_once_with(**expected_kwargs)
    assert expected_result == actual_result
Beispiel #16
0
def test_register_model_with_runs_uri():
    """A ``runs:/`` URI is resolved to its underlying source plus the run ID it names."""
    model_name = "Model 1"
    create_model_patch = mock.patch.object(
        MlflowClient, "create_registered_model",
        return_value=RegisteredModel(model_name))
    get_uri_patch = mock.patch(
        "mlflow.store.artifact.runs_artifact_repo.RunsArtifactRepository.get_underlying_uri",
        return_value="s3:/path/to/source")
    create_version_patch = mock.patch.object(
        MlflowClient, "create_model_version",
        return_value=ModelVersion(model_name, "1", creation_timestamp=123))
    with get_uri_patch, create_model_patch as create_model_mock, \
            create_version_patch as create_version_mock:
        register_model("runs:/run12345/path/to/model", model_name)
        create_model_mock.assert_called_once_with(model_name)
        create_version_mock.assert_called_once_with(
            model_name, "s3:/path/to/source", "run12345")
Beispiel #17
0
    def create_model_version(self, name, source, run_id, tags=None):
        """
        Register a new model version from a source location and run ID.

        :param name: Name of the registered model.
        :param source: Source path where the MLflow model is stored.
        :param run_id: ID of the MLflow tracking run that produced the model.
        :param tags: Optional list of :py:class:`mlflow.entities.model_registry.ModelVersionTag`
                     instances to attach to this version; ``None`` means no tags.
        :return: A single object of :py:class:`mlflow.entities.model_registry.ModelVersion`
                 created in the backend.
        """
        request = CreateModelVersion(name=name,
                                     source=source,
                                     run_id=run_id,
                                     tags=[t.to_proto() for t in (tags or [])])
        response_proto = self._call_endpoint(CreateModelVersion, message_to_json(request))
        return ModelVersion.from_proto(response_proto.model_version)
Beispiel #18
0
def test_transition_model_version_stage(mock_get_request_message,
                                        mock_model_registry_store):
    """The transition handler forwards its arguments to the store and serializes the result."""
    name = "model1"
    version = "32"
    stage = "Production"
    mock_get_request_message.return_value = TransitionModelVersionStage(
        name=name, version=version, stage=stage)
    mv = ModelVersion(name=name,
                      version=version,
                      creation_timestamp=123,
                      current_stage=stage)
    mock_model_registry_store.transition_model_version_stage.return_value = mv
    resp = _transition_stage()
    _, args = mock_model_registry_store.transition_model_version_stage.call_args
    assert args == {
        "name": name,
        "version": version,
        "stage": stage,
        "archive_existing_versions": False
    }
    # Like the other handler tests, also verify the serialized response body.
    assert json.loads(resp.get_data()) == {"model_version": jsonify(mv)}
Beispiel #19
0
def test_get_model_version_details(mock_get_request_message,
                                   mock_model_registry_store):
    """The get-model-version handler forwards name and version, then serializes the result."""
    requested = {"name": "model1", "version": "32"}
    mock_get_request_message.return_value = GetModelVersion(**requested)
    mvd = ModelVersion(name="model1", version="5", creation_timestamp=1,
                       last_updated_timestamp=12, description="v 5", user_id="u1",
                       current_stage="Production", source="A/B",
                       run_id=uuid.uuid4().hex, status="READY", status_message=None)
    mock_model_registry_store.get_model_version.return_value = mvd

    resp = _get_model_version()

    assert mock_model_registry_store.get_model_version.call_args[1] == requested
    assert json.loads(resp.get_data()) == {"model_version": jsonify(mvd)}
Beispiel #20
0
def test_models_artifact_repo_init_with_stage_uri(host_creds_mock):  # pylint: disable=unused-argument
    """A ``models:/<name>/<stage>`` URI wraps a DBFS repo at the latest version's location."""
    model_uri = "models:/MyModel/Production"
    artifact_location = "dbfs://databricks/mlflow-registry/12345/models/keras-model"
    latest_version = ModelVersion(name="MyModel",
                                  version="10",
                                  creation_timestamp="2345671890",
                                  last_updated_timestamp="234567890",
                                  description="some description",
                                  user_id="UserID",
                                  current_stage="Production",
                                  source="source",
                                  run_id="run12345")
    latest_patch = mock.patch.object(
        MlflowClient, "get_latest_versions", return_value=[latest_version])
    download_uri_patch = mock.patch.object(
        MlflowClient, "get_model_version_download_uri", return_value=artifact_location)
    with latest_patch, download_uri_patch:
        models_repo = ModelsArtifactRepository(model_uri)
        assert models_repo.artifact_uri == model_uri
        inner_repo = models_repo.repo
        assert isinstance(inner_repo, DbfsRestArtifactRepository)
        assert inner_repo.artifact_uri == artifact_location
Beispiel #21
0
def test_update_model_version(mock_get_request_message,
                              mock_model_registry_store):
    """The update handler forwards name, version and description, then serializes the result."""
    name = "model1"
    version = "32"
    description = "Great model!"
    mock_get_request_message.return_value = UpdateModelVersion(
        name=name, version=version, description=description)

    mv = ModelVersion(name=name,
                      version=version,
                      creation_timestamp=123,
                      description=description)
    mock_model_registry_store.update_model_version.return_value = mv
    resp = _update_model_version()
    _, args = mock_model_registry_store.update_model_version.call_args
    assert args == {
        "name": name,
        "version": version,
        "description": description
    }
    # Like the other handler tests, also verify the serialized response body.
    assert json.loads(resp.get_data()) == {"model_version": jsonify(mv)}
Beispiel #22
0
    def transition_model_version_stage(self, name, version, stage,
                                       archive_existing_versions):
        """
        Update model version stage.

        :param name: Registered model name.
        :param version: Registered model version; sent to the backend as a string.
        :param stage: New desired stage for this model version.
        :param archive_existing_versions: If this flag is set to ``True``, all existing model
            versions in the stage will be automatically moved to the "archived" stage. Only valid
            when ``stage`` is ``"staging"`` or ``"production"`` otherwise an error will be raised.

        :return: A single :py:class:`mlflow.entities.model_registry.ModelVersion` object.
        """
        req_body = message_to_json(
            TransitionModelVersionStage(
                name=name,
                version=str(version),
                stage=stage,
                archive_existing_versions=archive_existing_versions))
        response_proto = self._call_endpoint(TransitionModelVersionStage,
                                             req_body)
        return ModelVersion.from_proto(response_proto.model_version)