示例#1
0
def test_get_latest_versions(mock_get_request_message,
                             mock_model_registry_store):
    """Check ``_get_latest_versions`` forwards the model and stage filter to the store.

    A request without stages must reach the store as ``(rm, [])``, and the JSON
    response must wrap the store's result under ``"model_versions_detailed"``.
    Each stage filter in the loop must then be forwarded to the store unchanged.
    """
    rm = RegisteredModel("model1")
    mock_get_request_message.return_value = GetLatestVersions(
        registered_model=rm.to_proto())
    mvds = [
        ModelVersionDetailed(registered_model=rm,
                             version=5,
                             creation_timestamp=1,
                             last_updated_timestamp=12,
                             description="v 5",
                             user_id="u1",
                             current_stage="Production",
                             source="A/B",
                             run_id=uuid.uuid4().hex,
                             status="READY",
                             status_message=None),
        ModelVersionDetailed(registered_model=rm,
                             version=1,
                             creation_timestamp=1,
                             last_updated_timestamp=1200,
                             description="v 1",
                             user_id="u1",
                             current_stage="Archived",
                             source="A/B2",
                             run_id=uuid.uuid4().hex,
                             status="READY",
                             status_message=None),
        ModelVersionDetailed(registered_model=rm,
                             version=12,
                             creation_timestamp=100,
                             last_updated_timestamp=None,
                             description="v 12",
                             user_id="u2",
                             current_stage="Staging",
                             source="A/B3",
                             run_id=uuid.uuid4().hex,
                             status="READY",
                             status_message=None),
    ]
    store_call = mock_model_registry_store.get_latest_versions
    store_call.return_value = mvds
    resp = _get_latest_versions()
    args, _ = store_call.call_args
    assert args == (rm, [])
    assert json.loads(resp.get_data()) == {
        "model_versions_detailed": jsonify(mvds)
    }

    for stages in [[], ["None"], ["Staging"], ["Staging", "Production"]]:
        # Clear the recorded calls so each iteration must make its own store
        # call. Otherwise the stale call_args from the stage-less request above,
        # which is also (rm, []), would satisfy the first iteration even if
        # nothing was called. reset_mock() keeps the configured return_value.
        store_call.reset_mock()
        mock_get_request_message.return_value = GetLatestVersions(
            registered_model=rm.to_proto(), stages=stages)
        _get_latest_versions()
        assert store_call.call_count == 1
        args, _ = store_call.call_args
        assert args == (rm, stages)
 def test_get_latest_versions_with_stages(self, mock_http):
     """Requesting latest versions with a stage filter should POST get-latest-versions."""
     model = RegisteredModel("model_1")
     self.store.get_latest_versions(registered_model=model, stages=["blaah"])
     expected_request = GetLatestVersions(registered_model=model.to_proto(),
                                          stages=["blaah"])
     self._verify_requests(mock_http,
                           "registered-models/get-latest-versions",
                           "POST",
                           expected_request)
示例#3
0
def test_delete_registered_model(mock_get_request_message,
                                 mock_model_registry_store):
    """``_delete_registered_model`` should hand the requested model to the store."""
    target = RegisteredModel("model_1")
    delete_request = DeleteRegisteredModel(registered_model=target.to_proto())
    mock_get_request_message.return_value = delete_request
    _delete_registered_model()
    recorded_call = mock_model_registry_store.delete_registered_model.call_args
    positional, _ = recorded_call
    assert positional == (target,)
 def test_update_registered_model_description(self, mock_http):
     """Updating only the description should PATCH registered-models/update."""
     model = RegisteredModel("model_1")
     self.store.update_registered_model(registered_model=model,
                                        description="test model")
     expected_request = UpdateRegisteredModel(registered_model=model.to_proto(),
                                              description="test model")
     self._verify_requests(mock_http,
                           "registered-models/update",
                           "PATCH",
                           expected_request)
 def test_update_registered_model_name(self, mock_http):
     """Renaming a model should PATCH registered-models/update with the new name."""
     model = RegisteredModel("model_1")
     self.store.update_registered_model(registered_model=model,
                                        new_name="model_2")
     # The store's ``new_name`` argument is expected to travel as the proto's ``name`` field.
     expected_request = UpdateRegisteredModel(registered_model=model.to_proto(),
                                              name="model_2")
     self._verify_requests(mock_http,
                           "registered-models/update",
                           "PATCH",
                           expected_request)
 def test_update_registered_model_all(self, mock_http):
     """Renaming and re-describing together should send both fields in one PATCH."""
     model = RegisteredModel("model_1")
     self.store.update_registered_model(registered_model=model,
                                        new_name="model_3",
                                        description="rename and describe")
     expected_request = UpdateRegisteredModel(
         registered_model=model.to_proto(),
         name="model_3",
         description="rename and describe",
     )
     self._verify_requests(mock_http,
                           "registered-models/update",
                           "PATCH",
                           expected_request)
示例#7
0
def test_update_registered_model(mock_get_request_message, mock_model_registry_store):
    """``_update_registered_model`` should forward name/description and echo the store result."""
    original = RegisteredModel("model_1")
    update_request = UpdateRegisteredModel(
        registered_model=original.to_proto(),
        name="model_2",
        description="Test model",
    )
    mock_get_request_message.return_value = update_request
    renamed = RegisteredModel("model_2")
    store_update = mock_model_registry_store.update_registered_model
    store_update.return_value = renamed
    response = _update_registered_model()
    positional, _ = store_update.call_args
    assert positional == (original, u"model_2", u"Test model")
    payload = json.loads(response.get_data())
    assert payload == {"registered_model": jsonify(renamed)}
示例#8
0
def test_get_registered_model_details(mock_get_request_message, mock_model_registry_store):
    """Check ``_get_registered_model_details`` forwards the model and serializes the result.

    The store must receive exactly the requested model. The JSON response must
    wrap the store's result under ``"registered_model_detailed"``.
    """
    # Use one name for both the request and the mocked store result, so the
    # fixture describes the store returning details for the model that was asked for.
    model_name = "model_1"
    rm = RegisteredModel(model_name)
    mock_get_request_message.return_value = GetRegisteredModelDetails(
        registered_model=rm.to_proto())
    rmd = RegisteredModelDetailed(name=model_name, creation_timestamp=111,
                                  last_updated_timestamp=222, description="Test model",
                                  latest_versions=[])
    mock_model_registry_store.get_registered_model_details.return_value = rmd
    resp = _get_registered_model_details()
    args, _ = mock_model_registry_store.get_registered_model_details.call_args
    assert args == (rm, )
    assert json.loads(resp.get_data()) == {"registered_model_detailed": jsonify(rmd)}
 def test_delete_registered_model(self, mock_http):
     """Deleting a model should issue a DELETE to registered-models/delete."""
     target = RegisteredModel("model_1")
     self.store.delete_registered_model(registered_model=target)
     expected_request = DeleteRegisteredModel(registered_model=target.to_proto())
     self._verify_requests(mock_http,
                           "registered-models/delete",
                           "DELETE",
                           expected_request)
 def test_get_registered_model_detailed(self, mock_http):
     """Fetching model details should POST to registered-models/get-details."""
     target = RegisteredModel("model_1")
     self.store.get_registered_model_details(registered_model=target)
     expected_request = GetRegisteredModelDetails(registered_model=target.to_proto())
     self._verify_requests(mock_http,
                           "registered-models/get-details",
                           "POST",
                           expected_request)