def test_get_latest_versions_with_stages(self, mock_http):
    """Call ``get_latest_versions`` with a stage list and check the request.

    The expected endpoint, HTTP verb and ``GetLatestVersions`` message are
    handed to ``self._verify_requests`` together with ``mock_http``.
    """
    model = RegisteredModel("model_1")
    self.store.get_latest_versions(registered_model=model, stages=["blaah"])
    expected_message = GetLatestVersions(registered_model=model.to_proto(), stages=["blaah"])
    self._verify_requests(
        mock_http, "registered-models/get-latest-versions", "POST", expected_message
    )
Esempio n. 2
0
    def test_list_registered_model(self):
        """Check ``list_registered_models`` as models are created and then deleted."""

        def listed_names():
            # Query the store afresh on every call so each check sees current state.
            return {model.name for model in self.store.list_registered_models()}

        self._rm_maker("A")
        listing = self.store.list_registered_models()
        self.assertEqual(len(listing), 1)
        only_model = listing[0]
        self.assertEqual(only_model.name, "A")
        self.assertIsInstance(only_model, RegisteredModelDetailed)

        self._rm_maker("B")
        self.assertEqual(listed_names(), {"A", "B"})

        for extra_name in ("BB", "BA", "AB", "BBC"):
            self._rm_maker(extra_name)
        self.assertEqual(listed_names(), {"A", "B", "BB", "BA", "AB", "BBC"})

        # list should not return deleted models
        for removed_name in ("BA", "B"):
            self.store.delete_registered_model(RegisteredModel(removed_name))
        self.assertEqual(listed_names(), {"A", "BB", "AB", "BBC"})
Esempio n. 3
0
def test_search_model_versions(mock_get_request_message, mock_model_registry_store):
    """Drive ``_search_model_versions`` against a mocked store.

    The store mock returns four ``ModelVersionDetailed`` entries; the test checks
    the positional arguments the store received and the JSON body of the response.
    """
    filter_text = "source_path = 'A/B/CD'"
    mock_get_request_message.return_value = SearchModelVersions(filter=filter_text)

    def build(model_name, version, created, updated, description, user, stage):
        # Every fixture entry shares the same source, status and status message.
        return ModelVersionDetailed(
            RegisteredModel(name=model_name), version=version, creation_timestamp=created,
            last_updated_timestamp=updated, description=description, user_id=user,
            current_stage=stage, source="A/B/CD", run_id=uuid.uuid4().hex,
            status="READY", status_message=None)

    stored_versions = [
        build("model_1", 5, 100, 1200, "v 5", "u1", "Production"),
        build("model_1", 12, 110, 2000, "v 12", "u2", "Production"),
        build("ads_model", 8, 200, 2000, "v 8", "u1", "Staging"),
        build("fraud_detection_model", 345, 1000, 1001, "newest version", "u12", "None"),
    ]
    mock_model_registry_store.search_model_versions.return_value = stored_versions
    response = _search_model_versions()
    positional, _ = mock_model_registry_store.search_model_versions.call_args
    assert positional == (filter_text,)
    payload = json.loads(response.get_data())
    assert payload == {"model_versions_detailed": jsonify(stored_versions)}
def test_list_registered_models(mock_store):
    """List models through the client; expect one store call and two results."""
    stored_models = [RegisteredModel("Model 1"), RegisteredModel("Model 2")]
    mock_store.list_registered_models.return_value = PagedList(stored_models, "")
    client = newModelRegistryClient()
    listed = client.list_registered_models()
    mock_store.list_registered_models.assert_called_once()
    assert len(listed) == 2
def test_search_registered_models(mock_store):
    """Check how the client forwards search arguments to the store.

    Three calls are made: filter only, all arguments supplied, and
    ``max_results`` only. After each one the positional arguments received by
    the store mock and the returned page (length or names, plus token) are checked.
    """
    two_models = [RegisteredModel("Model 1"), RegisteredModel("Model 2")]
    mock_store.search_registered_models.return_value = PagedList(two_models, "")

    # Filter only: the store is expected to receive 100 as the second argument.
    page = newModelRegistryClient().search_registered_models(filter_string="test filter")
    mock_store.search_registered_models.assert_called_with("test filter", 100, None, None)
    assert len(page) == 2
    assert page.token == ""

    # Every argument supplied explicitly.
    full_query = {
        "filter_string": "another filter",
        "max_results": 12,
        "order_by": ["A", "B DESC"],
        "page_token": "next one",
    }
    page = newModelRegistryClient().search_registered_models(**full_query)
    mock_store.search_registered_models.assert_called_with(
        "another filter", 12, ["A", "B DESC"], "next one"
    )
    assert len(page) == 2
    assert page.token == ""

    # A different page from the store, including a non-empty token.
    three_models = [
        RegisteredModel("model A"),
        RegisteredModel("Model zz"),
        RegisteredModel("Model b"),
    ]
    mock_store.search_registered_models.return_value = PagedList(three_models, "page 2 token")
    page = newModelRegistryClient().search_registered_models(max_results=5)
    mock_store.search_registered_models.assert_called_with(None, 5, None, None)
    assert [model.name for model in page] == ["model A", "Model zz", "Model b"]
    assert page.token == "page 2 token"
Esempio n. 6
0
def test_list_registered_models(mock_get_request_message,
                                mock_model_registry_store):
    """Drive ``_list_registered_models`` with ``max_results=50`` against a mocked store.

    Checks the positional arguments the store received and that the response JSON
    carries both the serialized models and the next page token.
    """
    mock_get_request_message.return_value = ListRegisteredModels(max_results=50)
    first_model = RegisteredModel(
        name="model_1",
        creation_timestamp=111,
        last_updated_timestamp=222,
        description="Test model",
        latest_versions=[],
    )
    second_model = RegisteredModel(
        name="model_2",
        creation_timestamp=111,
        last_updated_timestamp=333,
        description="Another model",
        latest_versions=[],
    )
    stored_page = PagedList([first_model, second_model], "next_pt")
    mock_model_registry_store.list_registered_models.return_value = stored_page
    response = _list_registered_models()
    positional, _ = mock_model_registry_store.list_registered_models.call_args
    assert positional == (50, "")
    expected_payload = {
        "next_page_token": "next_pt",
        "registered_models": jsonify(stored_page),
    }
    assert json.loads(response.get_data()) == expected_payload
Esempio n. 7
0
def test_search_registered_models(mock_get_request_message, mock_model_registry_store):
    """Drive ``_search_registered_models`` with four request shapes.

    For each request the test checks the keyword arguments the store mock
    received and the JSON body of the response.
    """
    models = [
        RegisteredModel(
            name="model_1",
            creation_timestamp=111,
            last_updated_timestamp=222,
            description="Test model",
            latest_versions=[],
        ),
        RegisteredModel(
            name="model_2",
            creation_timestamp=111,
            last_updated_timestamp=333,
            description="Another model",
            latest_versions=[],
        ),
    ]

    def run_search(request_message, stored_page):
        # Install the request and the store result, invoke the handler, and hand
        # back the store's keyword arguments plus the decoded response body.
        mock_get_request_message.return_value = request_message
        mock_model_registry_store.search_registered_models.return_value = stored_page
        response = _search_registered_models()
        _, keyword_args = mock_model_registry_store.search_registered_models.call_args
        return keyword_args, json.loads(response.get_data())

    # Empty request; the stored page has no token.
    kwargs, payload = run_search(SearchRegisteredModels(), PagedList(models, None))
    assert kwargs == {"filter_string": "", "max_results": 100, "order_by": [], "page_token": ""}
    assert payload == {"registered_models": jsonify(models)}

    # Filter only.
    kwargs, payload = run_search(
        SearchRegisteredModels(filter="hello"), PagedList(models[:1], "tok")
    )
    assert kwargs == {
        "filter_string": "hello",
        "max_results": 100,
        "order_by": [],
        "page_token": "",
    }
    assert payload == {
        "registered_models": jsonify(models[:1]),
        "next_page_token": "tok",
    }

    # Filter plus max_results.
    kwargs, payload = run_search(
        SearchRegisteredModels(filter="hi", max_results=5), PagedList([models[0]], "tik")
    )
    assert kwargs == {"filter_string": "hi", "max_results": 5, "order_by": [], "page_token": ""}
    assert payload == {
        "registered_models": jsonify([models[0]]),
        "next_page_token": "tik",
    }

    # Every request field populated.
    kwargs, payload = run_search(
        SearchRegisteredModels(
            filter="hey", max_results=500, order_by=["a", "B desc"], page_token="prev"
        ),
        PagedList(models, "DONE"),
    )
    assert kwargs == {
        "filter_string": "hey",
        "max_results": 500,
        "order_by": ["a", "B desc"],
        "page_token": "prev",
    }
    assert payload == {
        "registered_models": jsonify(models),
        "next_page_token": "DONE",
    }
Esempio n. 8
0
def test_delete_registered_model(mock_get_request_message,
                                 mock_model_registry_store):
    """Run ``_delete_registered_model`` and check the model handed to the store mock."""
    model = RegisteredModel("model_1")
    request = DeleteRegisteredModel(registered_model=model.to_proto())
    mock_get_request_message.return_value = request
    _delete_registered_model()
    positional = mock_model_registry_store.delete_registered_model.call_args[0]
    assert positional == (model, )
 def test_update_registered_model_name(self, mock_http):
     """Update only the name of a registered model and check the outgoing request."""
     model = RegisteredModel("model_1")
     self.store.update_registered_model(registered_model=model, new_name="model_2")
     expected_message = UpdateRegisteredModel(registered_model=model.to_proto(),
                                              name="model_2")
     self._verify_requests(mock_http, "registered-models/update", "PATCH", expected_message)
 def test_update_registered_model_description(self, mock_http):
     """Update only the description of a registered model and check the outgoing request."""
     model = RegisteredModel("model_1")
     self.store.update_registered_model(registered_model=model, description="test model")
     expected_message = UpdateRegisteredModel(registered_model=model.to_proto(),
                                              description="test model")
     self._verify_requests(mock_http, "registered-models/update", "PATCH", expected_message)
Esempio n. 11
0
def test_search_model_versions(mock_store):
    """Search model versions through the client; expect one store call and two results."""
    stored_versions = [
        ModelVersion(RegisteredModel("Model 1"), number) for number in (1, 2)
    ]
    mock_store.search_model_versions.return_value = stored_versions
    client = newModelRegistryClient()
    found = client.search_model_versions("name=Model 1")
    mock_store.search_model_versions.assert_called_once_with("name=Model 1")
    assert len(found) == 2
 def test_update_registered_model_all(self, mock_http):
     """Update both name and description of a registered model and check the request."""
     model = RegisteredModel("model_1")
     self.store.update_registered_model(
         registered_model=model, new_name="model_3", description="rename and describe")
     expected_message = UpdateRegisteredModel(
         registered_model=model.to_proto(), name="model_3", description="rename and describe")
     self._verify_requests(mock_http, "registered-models/update", "PATCH", expected_message)
Esempio n. 13
0
def test_get_latest_versions(mock_get_request_message,
                             mock_model_registry_store):
    """Drive ``_get_latest_versions`` against a mocked store.

    First checks the store arguments and response body for a request without
    stages, then repeats the argument check for several stage lists.
    """
    model = RegisteredModel("model1")
    mock_get_request_message.return_value = GetLatestVersions(
        registered_model=model.to_proto())

    def build(version, created, updated, description, user, stage, source):
        # All fixture versions belong to the same model and share status fields.
        return ModelVersionDetailed(registered_model=model,
                                    version=version,
                                    creation_timestamp=created,
                                    last_updated_timestamp=updated,
                                    description=description,
                                    user_id=user,
                                    current_stage=stage,
                                    source=source,
                                    run_id=uuid.uuid4().hex,
                                    status="READY",
                                    status_message=None)

    stored_versions = [
        build(5, 1, 12, "v 5", "u1", "Production", "A/B"),
        build(1, 1, 1200, "v 1", "u1", "Archived", "A/B2"),
        build(12, 100, None, "v 12", "u2", "Staging", "A/B3"),
    ]
    mock_model_registry_store.get_latest_versions.return_value = stored_versions
    response = _get_latest_versions()
    positional, _ = mock_model_registry_store.get_latest_versions.call_args
    assert positional == (model, [])
    payload = json.loads(response.get_data())
    assert payload == {"model_versions_detailed": jsonify(stored_versions)}

    stage_cases = [[], ["None"], ["Staging"], ["Staging", "Production"]]
    for stage_list in stage_cases:
        mock_get_request_message.return_value = GetLatestVersions(
            registered_model=model.to_proto(), stages=stage_list)
        _get_latest_versions()
        positional, _ = mock_model_registry_store.get_latest_versions.call_args
        assert positional == (model, stage_list)
Esempio n. 14
0
def test_register_model_with_non_runs_uri():
    """Register a model from an ``s3:`` URI with ``MlflowClient`` methods patched.

    Checks that the registered model is created once and that the model version is
    created with ``run_id=None`` and the original URI as ``source``.
    """
    model_uri = "s3:/some/path/to/model"
    patched_create_model = mock.patch.object(
        MlflowClient, "create_registered_model", return_value=RegisteredModel("Model 1"))
    patched_create_version = mock.patch.object(
        MlflowClient, "create_model_version",
        return_value=ModelVersion(RegisteredModel("Model 1"), 1))
    # The assertions stay inside the ``with`` block, where the patched attributes exist.
    with patched_create_model, patched_create_version:
        register_model(model_uri, "Model 1")
        MlflowClient.create_registered_model.assert_called_once_with("Model 1")
        MlflowClient.create_model_version.assert_called_once_with(
            "Model 1", run_id=None, source=model_uri)
Esempio n. 15
0
def test_update_registered_model(mock_get_request_message, mock_model_registry_store):
    """Drive ``_update_registered_model`` and check store arguments and response body."""
    original_model = RegisteredModel("model_1")
    request = UpdateRegisteredModel(
        registered_model=original_model.to_proto(), name="model_2", description="Test model")
    mock_get_request_message.return_value = request
    renamed_model = RegisteredModel("model_2")
    mock_model_registry_store.update_registered_model.return_value = renamed_model
    response = _update_registered_model()
    positional = mock_model_registry_store.update_registered_model.call_args[0]
    assert positional == (original_model, u"model_2", u"Test model")
    assert json.loads(response.get_data()) == {"registered_model": jsonify(renamed_model)}
def test_rename_registered_model(mock_store):
    """Rename a model twice through the client and check the store call each time."""
    original_name = "Model 1"

    first_target = "New Name"
    mock_store.rename_registered_model.return_value = RegisteredModel(first_target)
    renamed = newModelRegistryClient().rename_registered_model(
        name=original_name, new_name=first_target)
    mock_store.rename_registered_model.assert_called_with(
        name=original_name, new_name=first_target)
    assert renamed.name == "New Name"

    mock_store.rename_registered_model.return_value = RegisteredModel("New Name 2")
    renamed = newModelRegistryClient().rename_registered_model(
        name=original_name, new_name="New Name 2")
    mock_store.rename_registered_model.assert_called_with(
        name=original_name, new_name="New Name 2")
    assert renamed.name == "New Name 2"
Esempio n. 17
0
def test_get_registered_model_details(mock_get_request_message, mock_model_registry_store):
    """Drive ``_get_registered_model_details`` and check store arguments and response body."""
    requested_model = RegisteredModel("model1")
    request = GetRegisteredModelDetails(registered_model=requested_model.to_proto())
    mock_get_request_message.return_value = request
    stored_details = RegisteredModelDetailed(
        name="model_1", creation_timestamp=111, last_updated_timestamp=222,
        description="Test model", latest_versions=[])
    mock_model_registry_store.get_registered_model_details.return_value = stored_details
    response = _get_registered_model_details()
    positional = mock_model_registry_store.get_registered_model_details.call_args[0]
    assert positional == (requested_model, )
    payload = json.loads(response.get_data())
    assert payload == {"registered_model_detailed": jsonify(stored_details)}
Esempio n. 18
0
def test_register_model_with_runs_uri():
    """Register a model from a ``runs:`` URI with client and artifact-repo calls patched.

    ``RunsArtifactRepository.get_underlying_uri`` is patched to return an ``s3:``
    path; the test checks that path and the run id reach ``create_model_version``.
    """
    underlying_uri = "s3:/path/to/source"
    patched_create_model = mock.patch.object(
        MlflowClient, "create_registered_model", return_value=RegisteredModel("Model 1"))
    patched_get_uri = mock.patch(
        "mlflow.store.artifact.runs_artifact_repo.RunsArtifactRepository.get_underlying_uri",
        return_value=underlying_uri)
    patched_create_version = mock.patch.object(
        MlflowClient, "create_model_version",
        return_value=ModelVersion(RegisteredModel("Model 1"), 1))
    # Patches are entered in the same order as before: URI lookup, model, version.
    with patched_get_uri, patched_create_model, patched_create_version:
        register_model("runs:/run12345/path/to/model", "Model 1")
        MlflowClient.create_registered_model.assert_called_once_with("Model 1")
        MlflowClient.create_model_version.assert_called_once_with(
            "Model 1", underlying_uri, "run12345")
Esempio n. 19
0
def test_list_registered_models(mock_get_request_message, mock_model_registry_store):
    """Drive ``_list_registered_models`` with an empty request against a mocked store.

    Checks that the store was called without positional arguments and that the
    response JSON contains the serialized models.
    """
    mock_get_request_message.return_value = ListRegisteredModels()
    first_model = RegisteredModel(
        name="model_1", creation_timestamp=111, last_updated_timestamp=222,
        description="Test model", latest_versions=[])
    second_model = RegisteredModel(
        name="model_2", creation_timestamp=111, last_updated_timestamp=333,
        description="Another model", latest_versions=[])
    stored_models = [first_model, second_model]
    mock_model_registry_store.list_registered_models.return_value = stored_models
    response = _list_registered_models()
    positional = mock_model_registry_store.list_registered_models.call_args[0]
    assert positional == ()
    assert json.loads(response.get_data()) == {"registered_models": jsonify(stored_models)}
Esempio n. 20
0
def _get_registered_model_details():
    """Handle a ``GetRegisteredModelDetails`` request.

    Parses the request, asks the model registry store for the details of the
    requested registered model, and returns them wrapped in a
    ``GetRegisteredModelDetails.Response``.
    """
    request = _get_request_message(GetRegisteredModelDetails())
    fetch_details = _get_model_registry_store().get_registered_model_details
    details = fetch_details(RegisteredModel.from_proto(request.registered_model))
    response = GetRegisteredModelDetails.Response(
        registered_model_detailed=details.to_proto())
    return _wrap_response(response)
 def test_get_model_version_download_uri(self, mock_http):
     """Request a model version's download URI and check the outgoing request."""
     version = ModelVersion(RegisteredModel("model_11"), 8)
     self.store.get_model_version_download_uri(model_version=version)
     expected_message = GetModelVersionDownloadUri(model_version=version.to_proto())
     self._verify_requests(
         mock_http, "model-versions/get-download-uri", "POST", expected_message)
Esempio n. 22
0
def _get_latest_versions():
    """Handle a ``GetLatestVersions`` request.

    Parses the request, asks the model registry store for the latest versions of
    the requested registered model (passing along the request's stages), and
    returns them in a ``GetLatestVersions.Response``.
    """
    request = _get_request_message(GetLatestVersions())
    fetch_latest = _get_model_registry_store().get_latest_versions
    versions = fetch_latest(RegisteredModel.from_proto(request.registered_model), request.stages)
    response = GetLatestVersions.Response()
    # Convert every entity first, then append them all in one call.
    version_protos = [version.to_proto() for version in versions]
    response.model_versions_detailed.extend(version_protos)
    return _wrap_response(response)
Esempio n. 23
0
    def search_registered_models(self,
                                 filter_string=None,
                                 max_results=None,
                                 order_by=None,
                                 page_token=None):
        """
        Search the backend for registered models matching the filter criteria.

        :param filter_string: Filter query string, defaults to searching all registered models.
        :param max_results: Maximum number of registered models desired.
        :param order_by: List of column names with ASC|DESC annotation, to be used for ordering
                         matching search results.
        :param page_token: Token specifying the next page of results. It should be obtained from
                            a ``search_registered_models`` call.
        :return: A PagedList of :py:class:`mlflow.entities.model_registry.RegisteredModel` objects
                that satisfy the search expressions. The pagination token for the next page can be
                obtained via the ``token`` attribute of the object.
        """
        request = SearchRegisteredModels(filter=filter_string,
                                         max_results=max_results,
                                         order_by=order_by,
                                         page_token=page_token)
        req_body = message_to_json(request)
        response = self._call_endpoint(SearchRegisteredModels, req_body)
        # Convert each proto in the response into its entity counterpart.
        found_models = list(map(RegisteredModel.from_proto, response.registered_models))
        return PagedList(found_models, response.next_page_token)
Esempio n. 24
0
def test_delete_model_version(mock_get_request_message, mock_model_registry_store):
    """Run ``_delete_model_version`` and check the model version handed to the store mock."""
    version = ModelVersion(registered_model=RegisteredModel("model1"), version=32)
    request = DeleteModelVersion(model_version=version.to_proto())
    mock_get_request_message.return_value = request
    _delete_model_version()
    positional = mock_model_registry_store.delete_model_version.call_args[0]
    assert positional == (version, )
 def test_get_model_version_stages(self, mock_http):
     """Request the stages of a model version and check the outgoing request."""
     version = ModelVersion(RegisteredModel("model_11"), 8)
     self.store.get_model_version_stages(model_version=version)
     expected_message = GetModelVersionStages(model_version=version.to_proto())
     self._verify_requests(mock_http, "model-versions/get-stages", "POST", expected_message)
 def test_update_model_version_stage(self, mock_http):
     """Update the stage of a model version and check the outgoing request."""
     version = ModelVersion(RegisteredModel("model_1"), 5)
     self.store.update_model_version(model_version=version, stage="prod")
     expected_message = UpdateModelVersion(model_version=version.to_proto(), stage="prod")
     self._verify_requests(mock_http, "model-versions/update", "PATCH", expected_message)
Esempio n. 27
0
def test_create_registered_model(mock_get_request_message, mock_model_registry_store):
    """Drive ``_create_registered_model`` and check store keyword arguments and response."""
    mock_get_request_message.return_value = CreateRegisteredModel(name="model_1")
    created_model = RegisteredModel("model_1")
    mock_model_registry_store.create_registered_model.return_value = created_model
    response = _create_registered_model()
    keyword_args = mock_model_registry_store.create_registered_model.call_args[1]
    assert keyword_args == {"name": "model_1"}
    assert json.loads(response.get_data()) == {"registered_model": jsonify(created_model)}
def test_create_registered_model(mock_store):
    """Create a tagged model through the client and check the tags passed to the store.

    The client is given a plain dict of tags; the store mock is expected to
    receive the matching list of ``RegisteredModelTag`` objects.
    """
    tags_dict = {"key": "value", "another key": "some other value"}
    tag_entities = [
        RegisteredModelTag(tag_key, tag_value) for tag_key, tag_value in tags_dict.items()
    ]
    stored_model = RegisteredModel("Model 1", tags=tag_entities)
    mock_store.create_registered_model.return_value = stored_model
    client = newModelRegistryClient()
    created = client.create_registered_model("Model 1", tags_dict)
    mock_store.create_registered_model.assert_called_once_with("Model 1", tag_entities)
    assert created.name == "Model 1"
    assert created.tags == tags_dict
Esempio n. 29
0
    def delete_registered_model(self, name):
        """
        Delete the registered model with the given name.
        Backend raises exception if a registered model with given name does not exist.

        :param name: Name of the registered model to delete.
        """
        remove_from_store = self.store.delete_registered_model
        remove_from_store(RegisteredModel(name))
Esempio n. 30
0
 def get_model_version_details(self, name, version):
     """
     Fetch the details of one model version from ``self.store``.

     :param name: Name of the containing registered model.
     :param version: Version number of the model version.
     :return: A single :py:class:`mlflow.entities.model_registry.ModelVersionDetailed` object.
     """
     fetch_details = self.store.get_model_version_details
     return fetch_details(ModelVersion(RegisteredModel(name), version))