def test_create_registered_model(mock_get_request_message, mock_model_registry_store):
    """Check the create-registered-model handler for a request with only a name.

    The request message is stubbed to a ``CreateRegisteredModel`` named
    ``model_1`` and the store's ``create_registered_model`` is stubbed to
    return a matching ``RegisteredModel``. The test then asserts that the
    store was called with exactly ``name="model_1"`` as keyword arguments and
    that the response body wraps the jsonified model under
    ``"registered_model"``.

    NOTE(review): another function with this exact name is defined later in
    the visible source. If both live in the same module, the later definition
    rebinds the name and this test is never collected -- confirm and rename
    one of them if both are meant to run.
    """
    # Arrange: stub the store result and the incoming request message.
    expected_model = RegisteredModel("model_1")
    create_mock = mock_model_registry_store.create_registered_model
    create_mock.return_value = expected_model
    mock_get_request_message.return_value = CreateRegisteredModel(name="model_1")

    # Act: invoke the handler under test.
    response = _create_registered_model()

    # Assert: call_args is an (args, kwargs) pair; only the kwargs matter here.
    store_kwargs = create_mock.call_args[1]
    assert store_kwargs == {"name": "model_1"}

    payload = json.loads(response.get_data())
    assert payload == {"registered_model": jsonify(expected_model)}
def test_create_registered_model(mock_get_request_message, mock_model_registry_store):
    """Check the create-registered-model handler for a request carrying tags.

    The request message is stubbed to a ``CreateRegisteredModel`` named
    ``model_1`` with two tags (converted via ``to_proto()``), and the store's
    ``create_registered_model`` is stubbed to return a ``RegisteredModel``
    with the same tags. The test asserts that the store received the name and
    a ``tags`` keyword whose key/value pairs match the originals, and that the
    response body wraps the jsonified model under ``"registered_model"``.

    NOTE(review): an earlier function with this exact name is visible in the
    source. If both live in the same module, this definition replaces the
    earlier one -- confirm that is intended.
    """
    # Arrange: build the tags, the request carrying them, and the store result.
    expected_tags = [
        RegisteredModelTag(key="key", value="value"),
        RegisteredModelTag(key="anotherKey", value="some other value"),
    ]
    tag_protos = [entry.to_proto() for entry in expected_tags]
    mock_get_request_message.return_value = CreateRegisteredModel(
        name="model_1",
        tags=tag_protos,
    )
    expected_model = RegisteredModel("model_1", tags=expected_tags)
    create_mock = mock_model_registry_store.create_registered_model
    create_mock.return_value = expected_model

    # Act: invoke the handler under test.
    response = _create_registered_model()

    # Assert: call_args is an (args, kwargs) pair; only the kwargs matter here.
    store_kwargs = create_mock.call_args[1]
    assert store_kwargs["name"] == "model_1"

    # Compare tags as key -> value mappings rather than by object identity.
    received_tags = {entry.key: entry.value for entry in store_kwargs["tags"]}
    wanted_tags = {entry.key: entry.value for entry in expected_tags}
    assert received_tags == wanted_tags

    payload = json.loads(response.get_data())
    assert payload == {"registered_model": jsonify(expected_model)}