def _set_model_version_tag():
    """Handle a request that attaches a tag to one version of a registered model.

    Reads the incoming ``SetModelVersionTag`` message and builds a
    ``ModelVersionTag`` from its ``key``/``value``. The tag is then handed to
    the model registry store along with the message's ``name`` and ``version``.

    :return: The result of wrapping an empty ``SetModelVersionTag.Response``.
    """
    message = _get_request_message(SetModelVersionTag())
    new_tag = ModelVersionTag(key=message.key, value=message.value)
    registry = _get_model_registry_store()
    registry.set_model_version_tag(
        name=message.name,
        version=message.version,
        tag=new_tag,
    )
    response = SetModelVersionTag.Response()
    return _wrap_response(response)
def test_set_model_version_tag(self, mock_http):
    """Tagging a model version through the store should hit ``model-versions/set-tag``.

    The expected request is a POST carrying a ``SetModelVersionTag`` message
    with the model name, version and tag key/value. The actual comparison is
    delegated to ``self._verify_requests``.
    """
    model_name = "model_1"
    model_version = "1"
    version_tag = ModelVersionTag(key="key", value="value")

    self.store.set_model_version_tag(name=model_name, version=model_version, tag=version_tag)

    expected_message = SetModelVersionTag(
        name=model_name,
        version=model_version,
        key=version_tag.key,
        value=version_tag.value,
    )
    self._verify_requests(mock_http, "model-versions/set-tag", "POST", expected_message)
def test_set_model_version_tag(mock_get_request_message, mock_model_registry_store):
    """The handler should pass the request's name, version and tag to the registry store."""
    model_name = "model1"
    model_version = "1"
    expected_tag = ModelVersionTag(key="some weird key", value="some value")
    mock_get_request_message.return_value = SetModelVersionTag(
        name=model_name,
        version=model_version,
        key=expected_tag.key,
        value=expected_tag.value,
    )

    _set_model_version_tag()

    # call_args unpacks to (positional_args, keyword_args); the handler uses keywords only.
    _, call_kwargs = mock_model_registry_store.set_model_version_tag.call_args
    assert call_kwargs == {"name": model_name, "version": model_version, "tag": expected_tag}
def set_model_version_tag(self, name, version, tag):
    """
    Attach a tag to one version of a registered model.

    :param name: Name of the registered model.
    :param version: Version of the registered model to tag.
    :param tag: :py:class:`mlflow.entities.model_registry.ModelVersionTag` instance whose
                ``key`` and ``value`` are sent in the request.
    :return: None
    """
    request = SetModelVersionTag(name=name, version=version, key=tag.key, value=tag.value)
    self._call_endpoint(SetModelVersionTag, message_to_json(request))