예제 #1
0
def _set_registered_model_tag():
    """Handle a request that sets a tag on a registered model.

    Parses the incoming ``SetRegisteredModelTag`` request message and builds a
    ``RegisteredModelTag`` from its ``key`` and ``value`` fields. It then passes
    that tag, together with the request's ``name``, to the model registry
    store's ``set_registered_model_tag``. Finally it returns a
    default-constructed ``SetRegisteredModelTag.Response`` passed through
    ``_wrap_response``.
    """
    msg = _get_request_message(SetRegisteredModelTag())
    new_tag = RegisteredModelTag(key=msg.key, value=msg.value)
    registry = _get_model_registry_store()
    registry.set_registered_model_tag(name=msg.name, tag=new_tag)
    empty_response = SetRegisteredModelTag.Response()
    return _wrap_response(empty_response)
예제 #2
0
 def test_set_registered_model_tag(self, mock_http):
     """Check that ``set_registered_model_tag`` hits the set-tag endpoint.

     After the store call, ``_verify_requests`` compares what ``mock_http``
     recorded against a ``POST`` to ``registered-models/set-tag``. The expected
     payload is a ``SetRegisteredModelTag`` message built from the same model
     name and tag key/value.
     """
     model_name = "model_1"
     model_tag = RegisteredModelTag(key="key", value="value")
     self.store.set_registered_model_tag(name=model_name, tag=model_tag)
     expected_message = SetRegisteredModelTag(
         name=model_name, key=model_tag.key, value=model_tag.value)
     self._verify_requests(
         mock_http, "registered-models/set-tag", "POST", expected_message)
예제 #3
0
def test_set_registered_model_tag(mock_get_request_message,
                                  mock_model_registry_store):
    """The handler forwards the request's model name and tag to the store.

    The mocked request message carries a model name plus a tag key/value.
    After the handler runs, the keyword arguments of the most recent
    ``set_registered_model_tag`` call on the mocked store must be exactly that
    name and an equal ``RegisteredModelTag``.
    """
    model_name = "model1"
    expected_tag = RegisteredModelTag(key="some weird key", value="some value")
    request = SetRegisteredModelTag(
        name=model_name, key=expected_tag.key, value=expected_tag.value)
    mock_get_request_message.return_value = request

    _set_registered_model_tag()

    # call_args unpacks as (positional_args, keyword_args); the handler passes
    # everything by keyword, so only the second half is checked.
    recorded = mock_model_registry_store.set_registered_model_tag.call_args
    _positional, kwargs = recorded
    assert kwargs == {"name": model_name, "tag": expected_tag}
예제 #4
0
    def set_registered_model_tag(self, name, tag):
        """
        Set a tag on a registered model.

        Builds a ``SetRegisteredModelTag`` request from the model name and the
        tag's key and value, serializes it with ``message_to_json``, and sends
        it through ``self._call_endpoint``.

        :param name: Registered model name.
        :param tag: :py:class:`mlflow.entities.model_registry.RegisteredModelTag` instance to log.
        :return: None
        """
        proto = SetRegisteredModelTag(name=name, key=tag.key, value=tag.value)
        json_body = message_to_json(proto)
        self._call_endpoint(SetRegisteredModelTag, json_body)