def _update_registered_model():
    """Handle an UpdateRegisteredModel request that addresses the model by name.

    The request's ``name`` and ``description`` fields are forwarded as keyword
    arguments to the model registry store's ``update_registered_model``. The
    model returned by the store is serialized into an
    ``UpdateRegisteredModel.Response`` and wrapped for the HTTP layer.

    NOTE(review): SOURCE also shows a second ``_update_registered_model`` with
    a different request shape; if both end up in one module the later
    definition shadows this one — confirm which is intended.
    """
    request = _get_request_message(UpdateRegisteredModel())
    store_kwargs = dict(name=request.name, description=request.description)
    updated = _get_model_registry_store().update_registered_model(**store_kwargs)
    response = UpdateRegisteredModel.Response(registered_model=updated.to_proto())
    return _wrap_response(response)
def _update_registered_model():
    """Handle an UpdateRegisteredModel request that may rename and/or re-describe a model.

    The target model is rebuilt from ``request.registered_model`` via
    ``RegisteredModel.from_proto``. ``name`` and ``description`` are optional.
    Each one is forwarded only when ``HasField`` reports it present, and
    ``None`` is forwarded otherwise. The store receives
    ``(model, new_name, new_description)`` positionally. The model it returns
    is serialized into an ``UpdateRegisteredModel.Response``.
    """
    request = _get_request_message(UpdateRegisteredModel())
    # HasField separates "field not sent" (-> None) from an explicitly sent value.
    new_name = request.name if request.HasField("name") else None
    new_description = (
        request.description if request.HasField("description") else None
    )
    store = _get_model_registry_store()
    updated = store.update_registered_model(
        RegisteredModel.from_proto(request.registered_model),
        new_name,
        new_description,
    )
    response = UpdateRegisteredModel.Response(registered_model=updated.to_proto())
    return _wrap_response(response)
def test_update_registered_model_description(self, mock_http):
    """Updating by name issues a PATCH to registered-models/update with name + description."""
    model_name, new_description = "model_1", "test model"
    self.store.update_registered_model(name=model_name, description=new_description)
    expected_request = UpdateRegisteredModel(name=model_name, description=new_description)
    self._verify_requests(
        mock_http, "registered-models/update", "PATCH", expected_request)
def test_update_registered_model_description(self, mock_http):
    """Setting only a description PATCHes registered-models/update with the model proto."""
    model = RegisteredModel("model_1")
    self.store.update_registered_model(registered_model=model, description="test model")
    expected_request = UpdateRegisteredModel(
        registered_model=model.to_proto(), description="test model")
    self._verify_requests(
        mock_http, "registered-models/update", "PATCH", expected_request)
def test_update_registered_model_name(self, mock_http):
    """Renaming only sends the model proto plus the new name in the PATCH body."""
    model = RegisteredModel("model_1")
    self.store.update_registered_model(registered_model=model, new_name="model_2")
    expected_request = UpdateRegisteredModel(
        registered_model=model.to_proto(), name="model_2")
    self._verify_requests(
        mock_http, "registered-models/update", "PATCH", expected_request)
def test_update_registered_model_all(self, mock_http):
    """Renaming and describing together sends both fields alongside the model proto."""
    model = RegisteredModel("model_1")
    self.store.update_registered_model(
        registered_model=model,
        new_name="model_3",
        description="rename and describe",
    )
    expected_request = UpdateRegisteredModel(
        registered_model=model.to_proto(),
        name="model_3",
        description="rename and describe",
    )
    self._verify_requests(
        mock_http, "registered-models/update", "PATCH", expected_request)
def test_update_registered_model(mock_get_request_message, mock_model_registry_store):
    """The name-based handler forwards name/description as kwargs and returns the model as JSON."""
    model_name = "model_1"
    model_description = "Test model"
    mock_get_request_message.return_value = UpdateRegisteredModel(
        name=model_name, description=model_description)
    returned_model = RegisteredModel(model_name, description=model_description)
    mock_model_registry_store.update_registered_model.return_value = returned_model

    response = _update_registered_model()

    # call_args is an (args, kwargs) pair; index 1 is the keyword arguments.
    store_kwargs = mock_model_registry_store.update_registered_model.call_args[1]
    assert store_kwargs == {"name": model_name, "description": u"Test model"}
    assert json.loads(response.get_data()) == {"registered_model": jsonify(returned_model)}
def test_update_registered_model(mock_get_request_message, mock_model_registry_store):
    """The handler passes (model, new name, description) positionally to the store."""
    original_model = RegisteredModel("model_1")
    mock_get_request_message.return_value = UpdateRegisteredModel(
        registered_model=original_model.to_proto(),
        name="model_2",
        description="Test model",
    )
    renamed_model = RegisteredModel("model_2")
    mock_model_registry_store.update_registered_model.return_value = renamed_model

    response = _update_registered_model()

    # call_args is an (args, kwargs) pair; index 0 is the positional arguments.
    positional = mock_model_registry_store.update_registered_model.call_args[0]
    assert positional == (original_model, u"model_2", u"Test model")
    assert json.loads(response.get_data()) == {"registered_model": jsonify(renamed_model)}
def update_registered_model(self, name, description):
    """
    Set a new description on a registered model identified by name.

    :param name: Name of the registered model to update.
    :param description: Description to store on the model.
    :return: The updated :py:class:`mlflow.entities.model_registry.RegisteredModel`,
             built from the ``registered_model`` field of the endpoint response.
    """
    request_proto = UpdateRegisteredModel(name=name, description=description)
    response_proto = self._call_endpoint(
        UpdateRegisteredModel, message_to_json(request_proto))
    return RegisteredModel.from_proto(response_proto.registered_model)
def update_registered_model(self, registered_model, new_name=None, description=None):
    """
    Update the name and/or description of a registered model.

    At least one of ``new_name`` or ``description`` is expected to be non-None.
    This method does not check that itself; validation is left to the backend,
    which raises if no registered model with the given name exists.

    :param registered_model: :py:class:`mlflow.entities.model_registry.RegisteredModel`
                             identifying the model to update; sent as its proto form.
    :param new_name: (Optional) New proposed name for the registered model.
    :param description: (Optional) New description.
    :return: The updated :py:class:`mlflow.entities.model_registry.RegisteredModel`,
             built from the ``registered_model`` field of the endpoint response.
    """
    request_proto = UpdateRegisteredModel(
        registered_model=registered_model.to_proto(),
        name=new_name,
        description=description,
    )
    response_proto = self._call_endpoint(
        UpdateRegisteredModel, message_to_json(request_proto))
    return RegisteredModel.from_proto(response_proto.registered_model)