class TestRestStore(unittest.TestCase):
    """Request-shape tests for the model-registry ``RestStore``.

    Each test patches ``mlflow.utils.rest_utils.http_request``, calls one
    store method, and asserts that a request with the expected endpoint,
    HTTP method and serialized proto payload was issued.
    """

    def setUp(self):
        self.creds = MlflowHostCreds('https://hello')
        # The store receives a zero-arg callable that yields host credentials.
        self.store = RestStore(lambda: self.creds)

    def tearDown(self):
        pass

    def _args(self, host_creds, endpoint, method, json_body):
        """Build the keyword arguments expected in an ``http_request`` call.

        :param host_creds: credentials the store is expected to forward.
        :param endpoint: path relative to ``/api/2.0/preview/mlflow/``.
        :param method: HTTP method string, e.g. ``"GET"`` or ``"POST"``.
        :param json_body: JSON text of the expected request message.
        :return: dict of expected kwargs; GET requests carry the payload as
            ``params``, all other methods carry it as ``json``.
        """
        res = {
            'host_creds': host_creds,
            'endpoint': "/api/2.0/preview/mlflow/%s" % endpoint,
            'method': method
        }
        if method == "GET":
            res["params"] = json.loads(json_body)
        else:
            res["json"] = json.loads(json_body)
        return res

    def _verify_requests(self, http_request, endpoint, method, proto_message):
        """Assert that ``http_request`` was called at least once with the
        arguments corresponding to ``endpoint``/``method``/``proto_message``.

        Uses ``assert_any_call``, so the matching call need not be the last.
        """
        # The unconditional debug print of ``call_args_list`` was removed; it
        # only polluted test output. Inspect ``http_request.call_args_list``
        # manually when debugging a failure.
        json_body = message_to_json(proto_message)
        http_request.assert_any_call(
            **(self._args(self.creds, endpoint, method, json_body)))

    @mock.patch('mlflow.utils.rest_utils.http_request')
    def test_create_registered_model(self, mock_http):
        self.store.create_registered_model("model_1")
        self._verify_requests(mock_http, "registered-models/create", "POST",
                              CreateRegisteredModel(name="model_1"))

    @mock.patch('mlflow.utils.rest_utils.http_request')
    def test_update_registered_model_name(self, mock_http):
        name = "model_1"
        new_name = "model_2"
        self.store.rename_registered_model(name=name, new_name=new_name)
        self._verify_requests(
            mock_http, "registered-models/rename", "POST",
            RenameRegisteredModel(name=name, new_name=new_name))

    @mock.patch('mlflow.utils.rest_utils.http_request')
    def test_update_registered_model_description(self, mock_http):
        name = "model_1"
        description = "test model"
        self.store.update_registered_model(name=name, description=description)
        self._verify_requests(
            mock_http, "registered-models/update", "PATCH",
            UpdateRegisteredModel(name=name, description=description))

    @mock.patch('mlflow.utils.rest_utils.http_request')
    def test_delete_registered_model(self, mock_http):
        name = "model_1"
        self.store.delete_registered_model(name=name)
        self._verify_requests(mock_http, "registered-models/delete", "DELETE",
                              DeleteRegisteredModel(name=name))

    @mock.patch('mlflow.utils.rest_utils.http_request')
    def test_list_registered_model(self, mock_http):
        self.store.list_registered_models()
        self._verify_requests(mock_http, "registered-models/list", "GET",
                              ListRegisteredModels())

    @mock.patch('mlflow.utils.rest_utils.http_request')
    def test_get_registered_model(self, mock_http):
        name = "model_1"
        self.store.get_registered_model(name=name)
        self._verify_requests(mock_http, "registered-models/get", "GET",
                              GetRegisteredModel(name=name))

    @mock.patch('mlflow.utils.rest_utils.http_request')
    def test_get_latest_versions(self, mock_http):
        name = "model_1"
        self.store.get_latest_versions(name=name)
        self._verify_requests(mock_http, "registered-models/get-latest-versions", "GET",
                              GetLatestVersions(name=name))

    @mock.patch('mlflow.utils.rest_utils.http_request')
    def test_get_latest_versions_with_stages(self, mock_http):
        name = "model_1"
        self.store.get_latest_versions(name=name, stages=["blaah"])
        self._verify_requests(mock_http, "registered-models/get-latest-versions", "GET",
                              GetLatestVersions(name=name, stages=["blaah"]))

    @mock.patch('mlflow.utils.rest_utils.http_request')
    def test_create_model_version(self, mock_http):
        run_id = uuid.uuid4().hex
        self.store.create_model_version("model_1", "path/to/source", run_id)
        self._verify_requests(
            mock_http, "model-versions/create", "POST",
            CreateModelVersion(name="model_1", source="path/to/source", run_id=run_id))

    @mock.patch('mlflow.utils.rest_utils.http_request')
    def test_transition_model_version_stage(self, mock_http):
        name = "model_1"
        version = "5"
        self.store.transition_model_version_stage(
            name=name, version=version, stage="prod", archive_existing_versions=True)
        self._verify_requests(
            mock_http, "model-versions/transition-stage", "POST",
            TransitionModelVersionStage(name=name, version=version, stage="prod",
                                        archive_existing_versions=True))

    @mock.patch('mlflow.utils.rest_utils.http_request')
    def test_update_model_version_decription(self, mock_http):
        name = "model_1"
        version = "5"
        description = "test model version"
        self.store.update_model_version(name=name, version=version,
                                        description=description)
        self._verify_requests(
            mock_http, "model-versions/update", "PATCH",
            UpdateModelVersion(name=name, version=version,
                               description="test model version"))

    @mock.patch('mlflow.utils.rest_utils.http_request')
    def test_delete_model_version(self, mock_http):
        name = "model_1"
        version = "12"
        self.store.delete_model_version(name=name, version=version)
        self._verify_requests(mock_http, "model-versions/delete", "DELETE",
                              DeleteModelVersion(name=name, version=version))

    @mock.patch('mlflow.utils.rest_utils.http_request')
    def test_get_model_version_details(self, mock_http):
        name = "model_11"
        version = "8"
        self.store.get_model_version(name=name, version=version)
        self._verify_requests(mock_http, "model-versions/get", "GET",
                              GetModelVersion(name=name, version=version))

    @mock.patch('mlflow.utils.rest_utils.http_request')
    def test_get_model_version_download_uri(self, mock_http):
        name = "model_11"
        version = "8"
        self.store.get_model_version_download_uri(name=name, version=version)
        self._verify_requests(
            mock_http, "model-versions/get-download-uri", "GET",
            GetModelVersionDownloadUri(name=name, version=version))

    @mock.patch('mlflow.utils.rest_utils.http_request')
    def test_search_model_versions(self, mock_http):
        self.store.search_model_versions(filter_string="name='model_12'")
        self._verify_requests(mock_http, "model-versions/search", "GET",
                              SearchModelVersions(filter="name='model_12'"))
class TestRestStore(unittest.TestCase):
    """Request-shape tests for the model-registry ``RestStore``.

    Each test patches the HTTP layer (via the ``mock_http_request`` /
    ``mock_multiple_http_requests`` decorators defined elsewhere in this
    file — presumably patching ``http_request``; confirm), calls one store
    method, and asserts on the endpoint, HTTP method and serialized proto
    payload that were sent.
    """

    def setUp(self):
        self.creds = MlflowHostCreds("https://hello")
        # The store receives a zero-arg callable that yields host credentials.
        self.store = RestStore(lambda: self.creds)

    def tearDown(self):
        pass

    def _args(self, host_creds, endpoint, method, json_body, preview=True):
        """Build the keyword arguments expected in an ``http_request`` call.

        :param host_creds: credentials the store is expected to forward.
        :param endpoint: path relative to the MLflow API root.
        :param method: HTTP method string, e.g. ``"GET"`` or ``"POST"``.
        :param json_body: JSON text of the expected request message.
        :param preview: if true, the endpoint is under
            ``/api/2.0/preview/mlflow/``; otherwise ``/api/2.0/mlflow/``.
        :return: dict of expected kwargs; GET requests carry the payload as
            ``params``, all other methods carry it as ``json``.
        """
        res = {
            "host_creds": host_creds,
            "endpoint": ("/api/2.0/preview/mlflow/%s"
                         if preview else "/api/2.0/mlflow/%s") % endpoint,
            "method": method,
        }
        if method == "GET":
            res["params"] = json.loads(json_body)
        else:
            res["json"] = json.loads(json_body)
        return res

    def _verify_requests(self, http_request, endpoint, method, proto_message):
        """Assert that ``http_request`` was called at least once with the
        (preview-API) arguments for ``endpoint``/``method``/``proto_message``.
        """
        json_body = message_to_json(proto_message)
        http_request.assert_any_call(
            **(self._args(self.creds, endpoint, method, json_body)))

    def _verify_all_requests(self, http_request, endpoints, proto_message):
        """Assert that ``http_request`` received one call per entry of
        ``endpoints``, consecutively and in the given order.

        :param endpoints: iterable of ``(endpoint, method, preview)`` tuples.
        :param proto_message: payload expected in every one of the calls.
        """
        json_body = message_to_json(proto_message)
        # ``assert_has_calls`` (default ``any_order=False``) requires the
        # calls to appear as a contiguous, ordered subsequence.
        http_request.assert_has_calls([
            mock.call(**(self._args(
                self.creds, endpoint, method, json_body, preview=preview)))
            for endpoint, method, preview in endpoints
        ])

    @mock_http_request
    def test_create_registered_model(self, mock_http):
        tags = [
            RegisteredModelTag(key="key", value="value"),
            RegisteredModelTag(key="anotherKey", value="some other value"),
        ]
        description = "best model ever"
        self.store.create_registered_model("model_1", tags, description)
        self._verify_requests(
            mock_http,
            "registered-models/create",
            "POST",
            CreateRegisteredModel(name="model_1",
                                  tags=[tag.to_proto() for tag in tags],
                                  description=description),
        )

    @mock_http_request
    def test_update_registered_model_name(self, mock_http):
        name = "model_1"
        new_name = "model_2"
        self.store.rename_registered_model(name=name, new_name=new_name)
        self._verify_requests(
            mock_http,
            "registered-models/rename",
            "POST",
            RenameRegisteredModel(name=name, new_name=new_name),
        )

    @mock_http_request
    def test_update_registered_model_description(self, mock_http):
        name = "model_1"
        description = "test model"
        self.store.update_registered_model(name=name, description=description)
        self._verify_requests(
            mock_http,
            "registered-models/update",
            "PATCH",
            UpdateRegisteredModel(name=name, description=description),
        )

    @mock_http_request
    def test_delete_registered_model(self, mock_http):
        name = "model_1"
        self.store.delete_registered_model(name=name)
        self._verify_requests(mock_http, "registered-models/delete", "DELETE",
                              DeleteRegisteredModel(name=name))

    @mock_http_request
    def test_list_registered_model(self, mock_http):
        self.store.list_registered_models(max_results=50, page_token=None)
        self._verify_requests(
            mock_http,
            "registered-models/list",
            "GET",
            ListRegisteredModels(page_token=None, max_results=50),
        )

    @mock_http_request
    def test_search_registered_model(self, mock_http):
        self.store.search_registered_models()
        self._verify_requests(mock_http, "registered-models/search", "GET",
                              SearchRegisteredModels())
        params_list = [
            {"filter_string": "model = 'yo'"},
            {"max_results": 400},
            {"page_token": "blah"},
            {"order_by": ["x", "Y"]},
        ]
        # Test every combination of params, from none of them up to all of
        # them; deriving the sizes from ``params_list`` keeps the loop correct
        # when entries are added or removed.
        for sz in range(len(params_list) + 1):
            for combination in combinations(params_list, sz):
                params = {k: v for d in combination for k, v in d.items()}
                self.store.search_registered_models(**params)
                # The store's ``filter_string`` kwarg maps to the proto's
                # ``filter`` field; rename only after the store call.
                if "filter_string" in params:
                    params["filter"] = params.pop("filter_string")
                self._verify_requests(mock_http, "registered-models/search", "GET",
                                      SearchRegisteredModels(**params))

    @mock_http_request
    def test_get_registered_model(self, mock_http):
        name = "model_1"
        self.store.get_registered_model(name=name)
        self._verify_requests(mock_http, "registered-models/get", "GET",
                              GetRegisteredModel(name=name))

    @mock_multiple_http_requests
    def test_get_latest_versions(self, mock_http):
        name = "model_1"
        self.store.get_latest_versions(name=name)
        endpoint = "registered-models/get-latest-versions"
        # Expected sequence: a POST to the non-preview endpoint, then a GET to
        # the preview endpoint.
        endpoints = [(endpoint, "POST", False), (endpoint, "GET", True)]
        self._verify_all_requests(mock_http, endpoints,
                                  GetLatestVersions(name=name))

    @mock_multiple_http_requests
    def test_get_latest_versions_with_stages(self, mock_http):
        name = "model_1"
        self.store.get_latest_versions(name=name, stages=["blaah"])
        endpoint = "registered-models/get-latest-versions"
        endpoints = [(endpoint, "POST", False), (endpoint, "GET", True)]
        self._verify_all_requests(
            mock_http, endpoints,
            GetLatestVersions(name=name, stages=["blaah"]))

    @mock_http_request
    def test_set_registered_model_tag(self, mock_http):
        name = "model_1"
        tag = RegisteredModelTag(key="key", value="value")
        self.store.set_registered_model_tag(name=name, tag=tag)
        self._verify_requests(
            mock_http,
            "registered-models/set-tag",
            "POST",
            SetRegisteredModelTag(name=name, key=tag.key, value=tag.value),
        )

    @mock_http_request
    def test_delete_registered_model_tag(self, mock_http):
        name = "model_1"
        self.store.delete_registered_model_tag(name=name, key="key")
        self._verify_requests(
            mock_http,
            "registered-models/delete-tag",
            "DELETE",
            DeleteRegisteredModelTag(name=name, key="key"),
        )

    @mock_http_request
    def test_create_model_version(self, mock_http):
        self.store.create_model_version("model_1", "path/to/source")
        self._verify_requests(
            mock_http,
            "model-versions/create",
            "POST",
            CreateModelVersion(name="model_1", source="path/to/source"),
        )
        # test optional fields
        run_id = uuid.uuid4().hex
        tags = [
            ModelVersionTag(key="key", value="value"),
            ModelVersionTag(key="anotherKey", value="some other value"),
        ]
        run_link = "localhost:5000/path/to/run"
        description = "version description"
        self.store.create_model_version(
            "model_1",
            "path/to/source",
            run_id,
            tags,
            run_link=run_link,
            description=description,
        )
        self._verify_requests(
            mock_http,
            "model-versions/create",
            "POST",
            CreateModelVersion(
                name="model_1",
                source="path/to/source",
                run_id=run_id,
                run_link=run_link,
                tags=[tag.to_proto() for tag in tags],
                description=description,
            ),
        )

    @mock_http_request
    def test_transition_model_version_stage(self, mock_http):
        name = "model_1"
        version = "5"
        self.store.transition_model_version_stage(
            name=name, version=version, stage="prod", archive_existing_versions=True)
        self._verify_requests(
            mock_http,
            "model-versions/transition-stage",
            "POST",
            TransitionModelVersionStage(name=name, version=version, stage="prod",
                                        archive_existing_versions=True),
        )

    @mock_http_request
    def test_update_model_version_decription(self, mock_http):
        name = "model_1"
        version = "5"
        description = "test model version"
        self.store.update_model_version(name=name, version=version,
                                        description=description)
        self._verify_requests(
            mock_http,
            "model-versions/update",
            "PATCH",
            UpdateModelVersion(name=name, version=version,
                               description="test model version"),
        )

    @mock_http_request
    def test_delete_model_version(self, mock_http):
        name = "model_1"
        version = "12"
        self.store.delete_model_version(name=name, version=version)
        self._verify_requests(
            mock_http,
            "model-versions/delete",
            "DELETE",
            DeleteModelVersion(name=name, version=version),
        )

    @mock_http_request
    def test_get_model_version_details(self, mock_http):
        name = "model_11"
        version = "8"
        self.store.get_model_version(name=name, version=version)
        self._verify_requests(mock_http, "model-versions/get", "GET",
                              GetModelVersion(name=name, version=version))

    @mock_http_request
    def test_get_model_version_download_uri(self, mock_http):
        name = "model_11"
        version = "8"
        self.store.get_model_version_download_uri(name=name, version=version)
        self._verify_requests(
            mock_http,
            "model-versions/get-download-uri",
            "GET",
            GetModelVersionDownloadUri(name=name, version=version),
        )

    @mock_http_request
    def test_search_model_versions(self, mock_http):
        self.store.search_model_versions(filter_string="name='model_12'")
        self._verify_requests(mock_http, "model-versions/search", "GET",
                              SearchModelVersions(filter="name='model_12'"))

    @mock_http_request
    def test_set_model_version_tag(self, mock_http):
        name = "model_1"
        tag = ModelVersionTag(key="key", value="value")
        self.store.set_model_version_tag(name=name, version="1", tag=tag)
        self._verify_requests(
            mock_http,
            "model-versions/set-tag",
            "POST",
            SetModelVersionTag(name=name, version="1", key=tag.key, value=tag.value),
        )

    @mock_http_request
    def test_delete_model_version_tag(self, mock_http):
        name = "model_1"
        self.store.delete_model_version_tag(name=name, version="1", key="key")
        self._verify_requests(
            mock_http,
            "model-versions/delete-tag",
            "DELETE",
            DeleteModelVersionTag(name=name, version="1", key="key"),
        )
class TestRestStore(unittest.TestCase):
    """Request-shape tests for the model-registry ``RestStore``
    (entity-object based API: ``RegisteredModel`` / ``ModelVersion``).

    Each test patches ``mlflow.utils.rest_utils.http_request``, calls one
    store method, and asserts that a request with the expected endpoint,
    HTTP method and serialized proto payload was issued.
    """

    def setUp(self):
        self.creds = MlflowHostCreds('https://hello')
        # The store receives a zero-arg callable that yields host credentials.
        self.store = RestStore(lambda: self.creds)

    def tearDown(self):
        pass

    def _args(self, host_creds, endpoint, method, json_body):
        """Build the keyword arguments expected in an ``http_request`` call.

        :param host_creds: credentials the store is expected to forward.
        :param endpoint: path relative to ``/api/2.0/preview/mlflow/``.
        :param method: HTTP method string, e.g. ``"GET"`` or ``"POST"``.
        :param json_body: JSON text of the expected request message.
        :return: dict of expected kwargs; GET requests carry the payload as
            ``params``, all other methods carry it as ``json``.
        """
        res = {
            'host_creds': host_creds,
            'endpoint': "/api/2.0/preview/mlflow/%s" % endpoint,
            'method': method
        }
        if method == "GET":
            res["params"] = json.loads(json_body)
        else:
            res["json"] = json.loads(json_body)
        return res

    def _verify_requests(self, http_request, endpoint, method, proto_message):
        """Assert that ``http_request`` was called at least once with the
        arguments corresponding to ``endpoint``/``method``/``proto_message``.

        Uses ``assert_any_call``, so the matching call need not be the last.
        """
        # The unconditional debug print of ``call_args_list`` was removed; it
        # only polluted test output. Inspect ``http_request.call_args_list``
        # manually when debugging a failure.
        json_body = message_to_json(proto_message)
        http_request.assert_any_call(
            **(self._args(self.creds, endpoint, method, json_body)))

    @mock.patch('mlflow.utils.rest_utils.http_request')
    def test_create_registered_model(self, mock_http):
        self.store.create_registered_model("model_1")
        self._verify_requests(mock_http, "registered-models/create", "POST",
                              CreateRegisteredModel(name="model_1"))

    @mock.patch('mlflow.utils.rest_utils.http_request')
    def test_update_registered_model_name(self, mock_http):
        rm = RegisteredModel("model_1")
        self.store.update_registered_model(registered_model=rm, new_name="model_2")
        self._verify_requests(
            mock_http, "registered-models/update", "PATCH",
            UpdateRegisteredModel(registered_model=rm.to_proto(), name="model_2"))

    @mock.patch('mlflow.utils.rest_utils.http_request')
    def test_update_registered_model_description(self, mock_http):
        rm = RegisteredModel("model_1")
        self.store.update_registered_model(registered_model=rm,
                                           description="test model")
        self._verify_requests(
            mock_http, "registered-models/update", "PATCH",
            UpdateRegisteredModel(registered_model=rm.to_proto(),
                                  description="test model"))

    @mock.patch('mlflow.utils.rest_utils.http_request')
    def test_update_registered_model_all(self, mock_http):
        rm = RegisteredModel("model_1")
        self.store.update_registered_model(registered_model=rm, new_name="model_3",
                                           description="rename and describe")
        self._verify_requests(
            mock_http, "registered-models/update", "PATCH",
            UpdateRegisteredModel(registered_model=rm.to_proto(), name="model_3",
                                  description="rename and describe"))

    @mock.patch('mlflow.utils.rest_utils.http_request')
    def test_delete_registered_model(self, mock_http):
        rm = RegisteredModel("model_1")
        self.store.delete_registered_model(registered_model=rm)
        self._verify_requests(
            mock_http, "registered-models/delete", "DELETE",
            DeleteRegisteredModel(registered_model=rm.to_proto()))

    @mock.patch('mlflow.utils.rest_utils.http_request')
    def test_list_registered_model(self, mock_http):
        self.store.list_registered_models()
        self._verify_requests(mock_http, "registered-models/list", "GET",
                              ListRegisteredModels())

    @mock.patch('mlflow.utils.rest_utils.http_request')
    def test_get_registered_model_detailed(self, mock_http):
        rm = RegisteredModel("model_1")
        self.store.get_registered_model_details(registered_model=rm)
        self._verify_requests(
            mock_http, "registered-models/get-details", "POST",
            GetRegisteredModelDetails(registered_model=rm.to_proto()))

    @mock.patch('mlflow.utils.rest_utils.http_request')
    def test_get_latest_versions(self, mock_http):
        rm = RegisteredModel("model_1")
        self.store.get_latest_versions(registered_model=rm)
        self._verify_requests(
            mock_http, "registered-models/get-latest-versions", "POST",
            GetLatestVersions(registered_model=rm.to_proto()))

    @mock.patch('mlflow.utils.rest_utils.http_request')
    def test_get_latest_versions_with_stages(self, mock_http):
        rm = RegisteredModel("model_1")
        self.store.get_latest_versions(registered_model=rm, stages=["blaah"])
        self._verify_requests(
            mock_http, "registered-models/get-latest-versions", "POST",
            GetLatestVersions(registered_model=rm.to_proto(), stages=["blaah"]))

    @mock.patch('mlflow.utils.rest_utils.http_request')
    def test_create_model_version(self, mock_http):
        run_id = uuid.uuid4().hex
        self.store.create_model_version("model_1", "path/to/source", run_id)
        self._verify_requests(
            mock_http, "model-versions/create", "POST",
            CreateModelVersion(name="model_1", source="path/to/source",
                               run_id=run_id))

    @mock.patch('mlflow.utils.rest_utils.http_request')
    def test_update_model_version_stage(self, mock_http):
        rm = RegisteredModel("model_1")
        mv = ModelVersion(rm, 5)
        self.store.update_model_version(model_version=mv, stage="prod")
        self._verify_requests(
            mock_http, "model-versions/update", "PATCH",
            UpdateModelVersion(model_version=mv.to_proto(), stage="prod"))

    @mock.patch('mlflow.utils.rest_utils.http_request')
    def test_update_model_version_decription(self, mock_http):
        rm = RegisteredModel("model_1")
        mv = ModelVersion(rm, 5)
        self.store.update_model_version(model_version=mv,
                                        description="test model version")
        self._verify_requests(
            mock_http, "model-versions/update", "PATCH",
            UpdateModelVersion(model_version=mv.to_proto(),
                               description="test model version"))

    @mock.patch('mlflow.utils.rest_utils.http_request')
    def test_update_model_version_all(self, mock_http):
        rm = RegisteredModel("model_1")
        mv = ModelVersion(rm, 5)
        self.store.update_model_version(model_version=mv, stage="5%",
                                        description="A|B test")
        self._verify_requests(
            mock_http, "model-versions/update", "PATCH",
            UpdateModelVersion(model_version=mv.to_proto(), stage="5%",
                               description="A|B test"))

    @mock.patch('mlflow.utils.rest_utils.http_request')
    def test_delete_model_version(self, mock_http):
        rm = RegisteredModel("model_1")
        mv = ModelVersion(rm, 12)
        self.store.delete_model_version(model_version=mv)
        self._verify_requests(mock_http, "model-versions/delete", "DELETE",
                              DeleteModelVersion(model_version=mv.to_proto()))

    @mock.patch('mlflow.utils.rest_utils.http_request')
    def test_get_model_version_details(self, mock_http):
        rm = RegisteredModel("model_11")
        mv = ModelVersion(rm, 8)
        self.store.get_model_version_details(model_version=mv)
        self._verify_requests(
            mock_http, "model-versions/get-details", "POST",
            GetModelVersionDetails(model_version=mv.to_proto()))

    @mock.patch('mlflow.utils.rest_utils.http_request')
    def test_get_model_version_download_uri(self, mock_http):
        rm = RegisteredModel("model_11")
        mv = ModelVersion(rm, 8)
        self.store.get_model_version_download_uri(model_version=mv)
        self._verify_requests(
            mock_http, "model-versions/get-download-uri", "POST",
            GetModelVersionDownloadUri(model_version=mv.to_proto()))

    @mock.patch('mlflow.utils.rest_utils.http_request')
    def test_search_model_versions(self, mock_http):
        self.store.search_model_versions(filter_string="name='model_12'")
        self._verify_requests(mock_http, "model-versions/search", "GET",
                              SearchModelVersions(filter="name='model_12'"))

    @mock.patch('mlflow.utils.rest_utils.http_request')
    def test_get_model_version_stages(self, mock_http):
        rm = RegisteredModel("model_11")
        mv = ModelVersion(rm, 8)
        self.store.get_model_version_stages(model_version=mv)
        self._verify_requests(
            mock_http, "model-versions/get-stages", "POST",
            GetModelVersionStages(model_version=mv.to_proto()))