コード例 #1
0
def _update_registered_model():
    """
    Handle an ``UpdateRegisteredModel`` request by updating a model's description.

    The model name and description are read from the request message and passed to the
    model registry store as keyword arguments. The updated entity is wrapped in an
    ``UpdateRegisteredModel.Response``.

    NOTE(review): ``description`` is forwarded unconditionally, so a request that omits it
    forwards the field's default value — confirm the store treats that as intended.
    """
    request = _get_request_message(UpdateRegisteredModel())
    model_name, description = request.name, request.description
    store = _get_model_registry_store()
    updated_model = store.update_registered_model(name=model_name, description=description)
    response = UpdateRegisteredModel.Response(registered_model=updated_model.to_proto())
    return _wrap_response(response)
コード例 #2
0
def _update_registered_model():
    """
    Handle an ``UpdateRegisteredModel`` request (rename and/or new description).

    Only fields explicitly set on the request are forwarded. An absent ``name`` or
    ``description`` reaches the store as ``None``. The store receives the target model
    (rebuilt from the request's ``registered_model`` proto), the new name and the new
    description positionally. The updated entity is wrapped in a response message.
    """
    request = _get_request_message(UpdateRegisteredModel())
    # HasField tells an unset field apart from one explicitly set to its default value.
    new_name = request.name if request.HasField("name") else None
    new_description = request.description if request.HasField("description") else None
    store = _get_model_registry_store()
    updated_model = store.update_registered_model(
        RegisteredModel.from_proto(request.registered_model), new_name, new_description)
    return _wrap_response(
        UpdateRegisteredModel.Response(registered_model=updated_model.to_proto()))
コード例 #3
0
 def test_update_registered_model_description(self, mock_http):
     """Updating by name sends a PATCH to registered-models/update with name and description."""
     model_name, new_description = "model_1", "test model"
     self.store.update_registered_model(name=model_name, description=new_description)
     expected_request = UpdateRegisteredModel(name=model_name, description=new_description)
     self._verify_requests(mock_http, "registered-models/update", "PATCH", expected_request)
 def test_update_registered_model_description(self, mock_http):
     """A description-only update sends the model proto plus the new description."""
     model = RegisteredModel("model_1")
     new_description = "test model"
     self.store.update_registered_model(registered_model=model, description=new_description)
     expected_request = UpdateRegisteredModel(registered_model=model.to_proto(),
                                              description=new_description)
     self._verify_requests(mock_http, "registered-models/update", "PATCH", expected_request)
 def test_update_registered_model_name(self, mock_http):
     """A rename-only update sends the model proto plus the new name."""
     model = RegisteredModel("model_1")
     target_name = "model_2"
     self.store.update_registered_model(registered_model=model, new_name=target_name)
     expected_request = UpdateRegisteredModel(registered_model=model.to_proto(),
                                              name=target_name)
     self._verify_requests(mock_http, "registered-models/update", "PATCH", expected_request)
 def test_update_registered_model_all(self, mock_http):
     """Renaming and describing at once sends the model proto with both new values."""
     model = RegisteredModel("model_1")
     target_name, new_description = "model_3", "rename and describe"
     self.store.update_registered_model(registered_model=model,
                                        new_name=target_name,
                                        description=new_description)
     expected_request = UpdateRegisteredModel(registered_model=model.to_proto(),
                                              name=target_name,
                                              description=new_description)
     self._verify_requests(mock_http, "registered-models/update", "PATCH", expected_request)
コード例 #7
0
ファイル: test_handlers.py プロジェクト: zirubak/mlflow
def test_update_registered_model(mock_get_request_message, mock_model_registry_store):
    """The handler forwards name and description to the store as keyword arguments."""
    model_name = "model_1"
    new_description = "Test model"
    mock_get_request_message.return_value = UpdateRegisteredModel(
        name=model_name, description=new_description)
    expected_model = RegisteredModel(model_name, description=new_description)
    store_update = mock_model_registry_store.update_registered_model
    store_update.return_value = expected_model
    response = _update_registered_model()
    # call_args unpacks as (positional_args, keyword_args); only the kwargs matter here.
    _, call_kwargs = store_update.call_args
    assert call_kwargs == {"name": model_name, "description": u"Test model"}
    assert json.loads(response.get_data()) == {"registered_model": jsonify(expected_model)}
コード例 #8
0
def test_update_registered_model(mock_get_request_message, mock_model_registry_store):
    """The handler passes (model, new name, new description) positionally to the store."""
    original_model = RegisteredModel("model_1")
    mock_get_request_message.return_value = UpdateRegisteredModel(
        registered_model=original_model.to_proto(),
        name="model_2",
        description="Test model")
    renamed_model = RegisteredModel("model_2")
    store_update = mock_model_registry_store.update_registered_model
    store_update.return_value = renamed_model
    response = _update_registered_model()
    # call_args unpacks as (positional_args, keyword_args); only the positionals matter here.
    positional_args, _ = store_update.call_args
    assert positional_args == (original_model, u"model_2", u"Test model")
    assert json.loads(response.get_data()) == {"registered_model": jsonify(renamed_model)}
コード例 #9
0
    def update_registered_model(self, name, description):
        """
        Set a new description on a registered model identified by name.

        :param name: Name of the registered model to update.
        :param description: Description to store on the model.
        :return: The updated :py:class:`mlflow.entities.model_registry.RegisteredModel`,
                 built from the endpoint's response.
        """
        request = UpdateRegisteredModel(name=name, description=description)
        response_proto = self._call_endpoint(UpdateRegisteredModel, message_to_json(request))
        return RegisteredModel.from_proto(response_proto.registered_model)
    def update_registered_model(self, registered_model, new_name=None, description=None):
        """
        Update the name and/or description of a registered model.

        Callers are expected to pass at least one of ``new_name`` or ``description``. This
        method does not check that itself; arguments left as ``None`` are handed to the
        ``UpdateRegisteredModel`` constructor unchanged. The server is expected to reject
        the request if the referenced registered model does not exist.

        :param registered_model: :py:class:`mlflow.entities.model_registry.RegisteredModel`
                                 identifying the model to update.
        :param new_name: (Optional) New name for the registered model.
        :param description: (Optional) New description.
        :return: The updated :py:class:`mlflow.entities.model_registry.RegisteredModel`,
                 built from the endpoint's response.
        """
        request = UpdateRegisteredModel(
            registered_model=registered_model.to_proto(),
            name=new_name,
            description=description,
        )
        response_proto = self._call_endpoint(UpdateRegisteredModel, message_to_json(request))
        return RegisteredModel.from_proto(response_proto.registered_model)