def test_get_model_version_download_uri(mock_get_request_message, mock_model_registry_store):
    """Verify the download-URI handler forwards name/version as keyword args.

    The request message is mocked to carry a model name and version. The
    handler's call into the registry store must pass them as the keyword
    arguments ``name`` and ``version``. The store's return value must come back
    as the ``artifact_uri`` field of the JSON response.

    NOTE(review): another function with this exact name is defined later in
    this module. The later definition replaces this one at import time, so
    pytest never collects this version. Confirm which variant should remain.
    """
    model_name = "model1"
    model_version = "32"
    download_path = "some/download/path"
    # Alias the mocked store method; MagicMock returns the same child each access.
    store_call = mock_model_registry_store.get_model_version_download_uri

    mock_get_request_message.return_value = GetModelVersionDownloadUri(
        name=model_name, version=model_version
    )
    store_call.return_value = download_path

    response = _get_model_version_download_uri()

    # call_args unpacks to (positional_args, keyword_args); only kwargs matter here.
    _positional, keyword_args = store_call.call_args
    assert keyword_args == dict(name=model_name, version=model_version)
    assert json.loads(response.get_data()) == {"artifact_uri": download_path}
def test_get_model_version_download_uri(mock_get_request_message, mock_model_registry_store):
    """Verify the download-URI handler passes a ModelVersion positionally.

    The request message is mocked to carry the proto form of a ModelVersion.
    The handler must call the registry store with a single positional argument
    equal to that ModelVersion. The store's return value must come back as the
    ``artifact_uri`` field of the JSON response.
    """
    registered_model = RegisteredModel("model1")
    model_version = ModelVersion(registered_model=registered_model, version=32)
    download_path = "some/download/path"
    # Alias the mocked store method; MagicMock returns the same child each access.
    store_call = mock_model_registry_store.get_model_version_download_uri

    request = GetModelVersionDownloadUri(model_version=model_version.to_proto())
    mock_get_request_message.return_value = request
    store_call.return_value = download_path

    response = _get_model_version_download_uri()

    # call_args unpacks to (positional_args, keyword_args); only positionals matter here.
    # NOTE(review): presumably relies on ModelVersion value equality if the handler
    # rebuilds the object from the proto — confirm.
    positional_args, _kwargs = store_call.call_args
    assert positional_args == (model_version,)
    assert json.loads(response.get_data()) == {"artifact_uri": download_path}