Exemplo n.º 1
0
def test_list_registered_models(mock_store):
    """Listing through the client hits the store once and yields two models."""
    stored_models = [RegisteredModel("Model 1"), RegisteredModel("Model 2")]
    mock_store.list_registered_models.return_value = PagedList(
        stored_models, "")

    listed = newModelRegistryClient().list_registered_models()

    mock_store.list_registered_models.assert_called_once()
    assert len(listed) == 2
Exemplo n.º 2
0
def test_search_registered_models(mock_store):
    """Check which positional arguments the client forwards to the store.

    Three searches are issued: one with only a filter string, one with every
    argument supplied, and one with only ``max_results``. After each search
    the store mock must have been called with the expected arguments, and the
    returned page must expose the models and token configured on the mock.
    """
    two_models = [RegisteredModel("Model 1"), RegisteredModel("Model 2")]
    mock_store.search_registered_models.return_value = PagedList(
        two_models, "")

    # Filter only: the store is expected to receive 100 as max_results.
    page = newModelRegistryClient().search_registered_models(
        filter_string="test filter")
    mock_store.search_registered_models.assert_called_with(
        "test filter", 100, None, None)
    assert len(page) == 2
    assert page.token == ""

    # Every argument supplied: each one is expected to be passed through.
    page = newModelRegistryClient().search_registered_models(
        filter_string="another filter",
        max_results=12,
        order_by=["A", "B DESC"],
        page_token="next one")
    mock_store.search_registered_models.assert_called_with(
        "another filter", 12, ["A", "B DESC"], "next one")
    assert len(page) == 2
    assert page.token == ""

    # New store result carrying a non-empty token; only max_results is given.
    expected_names = ["model A", "Model zz", "Model b"]
    mock_store.search_registered_models.return_value = PagedList(
        [RegisteredModel(model_name) for model_name in expected_names],
        "page 2 token")
    page = newModelRegistryClient().search_registered_models(max_results=5)
    mock_store.search_registered_models.assert_called_with(None, 5, None, None)
    assert [model.name for model in page] == expected_names
    assert page.token == "page 2 token"
Exemplo n.º 3
0
    def search_registered_models(self,
                                 filter_string=None,
                                 max_results=None,
                                 order_by=None,
                                 page_token=None):
        """
        Query the backend for registered models matching the filter criteria.

        :param filter_string: Filter query string; when omitted, all registered models are
                              searched.
        :param max_results: Maximum number of registered models desired.
        :param order_by: List of column names with ASC|DESC annotation used to order the
                         matching search results.
        :param page_token: Token specifying the next page of results, as obtained from a
                           previous ``search_registered_models`` call.
        :return: A PagedList of :py:class:`mlflow.entities.model_registry.RegisteredModel`
                 objects that satisfy the search expressions. The token for the next page is
                 available through the ``token`` attribute of the returned object.
        """
        request = SearchRegisteredModels(filter=filter_string,
                                         max_results=max_results,
                                         order_by=order_by,
                                         page_token=page_token)
        response_proto = self._call_endpoint(SearchRegisteredModels,
                                             message_to_json(request))
        # Convert each proto message in the response into its entity form.
        models = list(
            map(RegisteredModel.from_proto, response_proto.registered_models))
        return PagedList(models, response_proto.next_page_token)
Exemplo n.º 4
0
def test_rename_registered_model(mock_store):
    """Renaming via the client forwards keyword args and returns the store's model.

    The same scenario is run for two target names in sequence.
    """
    current_name = "Model 1"
    for target_name in ("New Name", "New Name 2"):
        mock_store.rename_registered_model.return_value = RegisteredModel(
            target_name)

        renamed = newModelRegistryClient().rename_registered_model(
            name=current_name, new_name=target_name)

        mock_store.rename_registered_model.assert_called_with(
            name=current_name, new_name=target_name)
        assert renamed.name == target_name
Exemplo n.º 5
0
    def get_registered_model(self, name):
        """
        Fetch a registered model by its name.

        :param name: Registered model name.
        :return: A single :py:class:`mlflow.entities.model_registry.RegisteredModel` object.
        """
        request = GetRegisteredModel(name=name)
        response_proto = self._call_endpoint(GetRegisteredModel,
                                             message_to_json(request))
        model_proto = response_proto.registered_model
        return RegisteredModel.from_proto(model_proto)
Exemplo n.º 6
0
def test_create_registered_model(mock_store):
    """Creating a model with a tag dict hands tag entities to the store.

    The store mock must be called exactly once with the model name and the
    list of tag entities, and the result must expose the tags as a dict equal
    to the input.
    """
    tags_dict = {"key": "value", "another key": "some other value"}
    tag_entities = [
        RegisteredModelTag(tag_key, tag_value)
        for tag_key, tag_value in tags_dict.items()
    ]
    mock_store.create_registered_model.return_value = RegisteredModel(
        "Model 1", tags=tag_entities)

    created = newModelRegistryClient().create_registered_model(
        "Model 1", tags_dict)

    mock_store.create_registered_model.assert_called_once_with(
        "Model 1", tag_entities)
    assert created.name == "Model 1"
    assert created.tags == tags_dict
Exemplo n.º 7
0
def test_update_registered_model(mock_store):
    """Updating a description forwards keyword args and returns the store's model.

    The same scenario is run for two descriptions in sequence.
    """
    model_name = "Model 1"
    for description in ("New Description", "New Description 2"):
        mock_store.update_registered_model.return_value = RegisteredModel(
            model_name, description=description)

        updated = newModelRegistryClient().update_registered_model(
            name=model_name, description=description)

        mock_store.update_registered_model.assert_called_with(
            name=model_name, description=description)
        assert updated.description == description
Exemplo n.º 8
0
    def update_registered_model(self, name, description):
        """
        Change the description of a registered model.

        :param name: Registered model name.
        :param description: New description.
        :return: A single updated :py:class:`mlflow.entities.model_registry.RegisteredModel` object.
        """
        request = UpdateRegisteredModel(name=name, description=description)
        response_proto = self._call_endpoint(UpdateRegisteredModel,
                                             message_to_json(request))
        model_proto = response_proto.registered_model
        return RegisteredModel.from_proto(model_proto)
Exemplo n.º 9
0
    def rename_registered_model(self, name, new_name):
        """
        Give a registered model a new name.

        :param name: Registered model name.
        :param new_name: New proposed name.
        :return: A single updated :py:class:`mlflow.entities.model_registry.RegisteredModel` object.
        """
        request = RenameRegisteredModel(name=name, new_name=new_name)
        response_proto = self._call_endpoint(RenameRegisteredModel,
                                             message_to_json(request))
        model_proto = response_proto.registered_model
        return RegisteredModel.from_proto(model_proto)
Exemplo n.º 10
0
def test_rename_registered_model(mock_get_request_message,
                                 mock_model_registry_store):
    """The rename handler passes name/new_name to the store and returns its model as JSON."""
    old_name, target_name = "model_1", "model_2"
    mock_get_request_message.return_value = RenameRegisteredModel(
        name=old_name, new_name=target_name)
    renamed_model = RegisteredModel(target_name)
    mock_model_registry_store.rename_registered_model.return_value = renamed_model

    response = _rename_registered_model()

    # call_args is an (args, kwargs) pair; only the keyword arguments matter here.
    store_kwargs = mock_model_registry_store.rename_registered_model.call_args[1]
    assert store_kwargs == {"name": old_name, "new_name": target_name}
    payload = json.loads(response.get_data())
    assert payload == {"registered_model": jsonify(renamed_model)}
Exemplo n.º 11
0
def test_update_registered_model(mock_get_request_message,
                                 mock_model_registry_store):
    """The update handler passes name/description to the store and returns its model as JSON."""
    model_name = "model_1"
    model_description = "Test model"
    mock_get_request_message.return_value = UpdateRegisteredModel(
        name=model_name, description=model_description)
    updated_model = RegisteredModel(model_name, description=model_description)
    mock_model_registry_store.update_registered_model.return_value = updated_model

    response = _update_registered_model()

    # call_args is an (args, kwargs) pair; only the keyword arguments matter here.
    store_kwargs = mock_model_registry_store.update_registered_model.call_args[1]
    assert store_kwargs == {"name": model_name, "description": u"Test model"}
    payload = json.loads(response.get_data())
    assert payload == {"registered_model": jsonify(updated_model)}
Exemplo n.º 12
0
def test_get_registered_model(mock_get_request_message,
                              mock_model_registry_store):
    """The get handler looks the model up by name and returns it as JSON."""
    model_name = "model1"
    mock_get_request_message.return_value = GetRegisteredModel(name=model_name)
    stored_model = RegisteredModel(
        name=model_name,
        creation_timestamp=111,
        last_updated_timestamp=222,
        description="Test model",
        latest_versions=[],
    )
    mock_model_registry_store.get_registered_model.return_value = stored_model

    response = _get_registered_model()

    # call_args is an (args, kwargs) pair; only the keyword arguments matter here.
    store_kwargs = mock_model_registry_store.get_registered_model.call_args[1]
    assert store_kwargs == {"name": model_name}
    payload = json.loads(response.get_data())
    assert payload == {"registered_model": jsonify(stored_model)}
Exemplo n.º 13
0
 def to_mlflow_entity(self):
     """
     Build the ``RegisteredModel`` entity for this row.

     ``self.model_versions`` holds every version of the model, so only the
     highest-numbered version per stage is kept; versions whose stage is
     ``STAGE_DELETED_INTERNAL`` are left out.
     """
     newest_by_stage = {}
     for candidate in self.model_versions:
         stage = candidate.current_stage
         if stage == STAGE_DELETED_INTERNAL:
             continue
         incumbent = newest_by_stage.get(stage)
         # First version seen for this stage, or a higher version number wins.
         if incumbent is None or incumbent.version < candidate.version:
             newest_by_stage[stage] = candidate

     latest_version_entities = [
         version.to_mlflow_entity() for version in newest_by_stage.values()
     ]
     tag_entities = [
         tag.to_mlflow_entity() for tag in self.registered_model_tags
     ]
     return RegisteredModel(self.name, self.creation_time,
                            self.last_updated_time, self.description,
                            latest_version_entities, tag_entities)
Exemplo n.º 14
0
def test_list_registered_models(mock_get_request_message,
                                mock_model_registry_store):
    """The list handler forwards paging arguments and returns models plus next token."""
    mock_get_request_message.return_value = ListRegisteredModels(
        max_results=50)
    first_model = RegisteredModel(name="model_1",
                                  creation_timestamp=111,
                                  last_updated_timestamp=222,
                                  description="Test model",
                                  latest_versions=[])
    second_model = RegisteredModel(name="model_2",
                                   creation_timestamp=111,
                                   last_updated_timestamp=333,
                                   description="Another model",
                                   latest_versions=[])
    store_page = PagedList([first_model, second_model], "next_pt")
    mock_model_registry_store.list_registered_models.return_value = store_page

    response = _list_registered_models()

    # call_args is an (args, kwargs) pair; this handler is checked on positionals.
    store_args = mock_model_registry_store.list_registered_models.call_args[0]
    assert store_args == (50, '')
    payload = json.loads(response.get_data())
    assert payload == {
        "next_page_token": "next_pt",
        "registered_models": jsonify(store_page)
    }
Exemplo n.º 15
0
    def create_registered_model(self, name, tags=None):
        """
        Register a new model in the backend store.

        :param name: Name of the new model. This is expected to be unique in the backend store.
        :param tags: A list of :py:class:`mlflow.entities.model_registry.RegisteredModelTag`
                     instances associated with this registered model.
        :return: A single object of :py:class:`mlflow.entities.model_registry.RegisteredModel`
                 created in the backend.
        """
        # A missing or empty ``tags`` argument results in no tags on the request.
        proto_tags = [tag.to_proto() for tag in tags] if tags else []
        request = CreateRegisteredModel(name=name, tags=proto_tags)
        response_proto = self._call_endpoint(CreateRegisteredModel,
                                             message_to_json(request))
        model_proto = response_proto.registered_model
        return RegisteredModel.from_proto(model_proto)
Exemplo n.º 16
0
def test_register_model_with_non_runs_uri():
    """Registering from a plain ``s3:`` URI passes that URI through as the source.

    Both client methods are patched on ``MlflowClient``; the version must be
    created with ``run_id=None`` and the original URI as ``source``.
    """
    registered_model = RegisteredModel("Model 1")
    model_version = ModelVersion("Model 1", "1", creation_timestamp=123)
    patch_create_model = mock.patch.object(MlflowClient,
                                           "create_registered_model",
                                           return_value=registered_model)
    patch_create_version = mock.patch.object(MlflowClient,
                                             "create_model_version",
                                             return_value=model_version)
    with patch_create_model:
        with patch_create_version:
            register_model("s3:/some/path/to/model", "Model 1")
            MlflowClient.create_registered_model.assert_called_once_with(
                "Model 1")
            MlflowClient.create_model_version.assert_called_once_with(
                "Model 1", run_id=None, source="s3:/some/path/to/model")
Exemplo n.º 17
0
def test_create_registered_model(mock_get_request_message,
                                 mock_model_registry_store):
    """The create handler hands name and tags to the store and returns its model as JSON."""

    def as_mapping(tag_list):
        # Compare tags by key/value rather than by object identity.
        return {tag.key: tag.value for tag in tag_list}

    tags = [
        RegisteredModelTag(key="key", value="value"),
        RegisteredModelTag(key="anotherKey", value="some other value")
    ]
    proto_tags = [tag.to_proto() for tag in tags]
    mock_get_request_message.return_value = CreateRegisteredModel(
        name="model_1", tags=proto_tags)
    created_model = RegisteredModel("model_1", tags=tags)
    mock_model_registry_store.create_registered_model.return_value = created_model

    response = _create_registered_model()

    # call_args is an (args, kwargs) pair; only the keyword arguments matter here.
    store_kwargs = mock_model_registry_store.create_registered_model.call_args[1]
    assert store_kwargs["name"] == "model_1"
    assert as_mapping(store_kwargs["tags"]) == as_mapping(tags)
    payload = json.loads(response.get_data())
    assert payload == {"registered_model": jsonify(created_model)}
Exemplo n.º 18
0
def test_register_model_with_runs_uri():
    """Registering from a ``runs:`` URI uses the patched underlying URI as the source.

    ``RunsArtifactRepository.get_underlying_uri`` is patched to return an
    ``s3:`` path; the version must be created with that path and the run id
    taken from the ``runs:`` URI.
    """
    registered_model = RegisteredModel("Model 1")
    patch_create_model = mock.patch.object(MlflowClient,
                                           "create_registered_model",
                                           return_value=registered_model)
    patch_underlying_uri = mock.patch(
        "mlflow.store.artifact.runs_artifact_repo.RunsArtifactRepository.get_underlying_uri",
        return_value="s3:/path/to/source")
    model_version = ModelVersion("Model 1", "1", creation_timestamp=123)
    patch_create_version = mock.patch.object(MlflowClient,
                                             "create_model_version",
                                             return_value=model_version)
    with patch_underlying_uri:
        with patch_create_model:
            with patch_create_version:
                register_model("runs:/run12345/path/to/model", "Model 1")
                MlflowClient.create_registered_model.assert_called_once_with(
                    "Model 1")
                MlflowClient.create_model_version.assert_called_once_with(
                    "Model 1", "s3:/path/to/source", "run12345")
Exemplo n.º 19
0
    def list_registered_models(self, max_results, page_token):
        """
        Fetch one page of the list of all registered models.

        :param max_results: Maximum number of registered models desired.
        :param page_token: Token specifying the next page of results, as obtained from a
                           previous ``list_registered_models`` call.
        :return: A PagedList of :py:class:`mlflow.entities.model_registry.RegisteredModel`
                 objects. The token for the next page is available through the ``token``
                 attribute of the returned object.
        """
        request = ListRegisteredModels(page_token=page_token,
                                       max_results=max_results)
        response_proto = self._call_endpoint(ListRegisteredModels,
                                             message_to_json(request))
        # Convert each proto message in the response into its entity form.
        models = [
            RegisteredModel.from_proto(model_proto)
            for model_proto in response_proto.registered_models
        ]
        return PagedList(models, response_proto.next_page_token)
Exemplo n.º 20
0
def test_get_registered_model_details(mock_store):
    """Fetching a model returns its name, its three latest versions and tags as a dict."""
    model_name = "Model 1"
    tag_entities = [
        RegisteredModelTag("key", "value"),
        RegisteredModelTag("another key", "some other value")
    ]
    versions = [
        _model_version("Model 1", 3, "None"),
        _model_version("Model 1", 2, "Staging"),
        _model_version("Model 1", 1, "Production")
    ]
    mock_store.get_registered_model.return_value = RegisteredModel(
        model_name,
        "1263283747835",
        "1283168374623874",
        "I am a model",
        versions,
        tags=tag_entities)

    fetched = newModelRegistryClient().get_registered_model(model_name)

    mock_store.get_registered_model.assert_called_once()
    assert fetched.name == model_name
    assert len(fetched.latest_versions) == 3
    expected_tags = {tag.key: tag.value for tag in tag_entities}
    assert fetched.tags == expected_tags
Exemplo n.º 21
0
def test_search_registered_models(mock_get_request_message,
                                  mock_model_registry_store):
    """Check the keyword arguments the search handler passes to the store.

    Four requests are exercised: an empty one, one with only a filter, one
    with a filter and ``max_results``, and one with every field set. For each,
    the store's keyword arguments and the JSON response body are compared
    against the expected values.
    """
    all_models = [
        RegisteredModel(name="model_1",
                        creation_timestamp=111,
                        last_updated_timestamp=222,
                        description="Test model",
                        latest_versions=[]),
        RegisteredModel(name="model_2",
                        creation_timestamp=111,
                        last_updated_timestamp=333,
                        description="Another model",
                        latest_versions=[]),
    ]

    def run_search(request_message, store_result):
        # Arrange both mocks, invoke the handler, and report what the store
        # received (keyword arguments) alongside the decoded response body.
        mock_get_request_message.return_value = request_message
        mock_model_registry_store.search_registered_models.return_value = store_result
        response = _search_registered_models()
        store_kwargs = mock_model_registry_store.search_registered_models.call_args[1]
        return store_kwargs, json.loads(response.get_data())

    # Empty request; the store result has no next-page token.
    store_kwargs, payload = run_search(SearchRegisteredModels(),
                                       PagedList(all_models, None))
    assert store_kwargs == {
        "filter_string": "",
        "max_results": 100,
        "order_by": [],
        "page_token": ""
    }
    assert payload == {"registered_models": jsonify(all_models)}

    # Filter only.
    store_kwargs, payload = run_search(SearchRegisteredModels(filter="hello"),
                                       PagedList(all_models[:1], "tok"))
    assert store_kwargs == {
        "filter_string": "hello",
        "max_results": 100,
        "order_by": [],
        "page_token": ""
    }
    assert payload == {
        "registered_models": jsonify(all_models[:1]),
        "next_page_token": "tok"
    }

    # Filter plus an explicit max_results.
    store_kwargs, payload = run_search(
        SearchRegisteredModels(filter="hi", max_results=5),
        PagedList([all_models[0]], "tik"))
    assert store_kwargs == {
        "filter_string": "hi",
        "max_results": 5,
        "order_by": [],
        "page_token": ""
    }
    assert payload == {
        "registered_models": jsonify([all_models[0]]),
        "next_page_token": "tik"
    }

    # Every request field populated.
    store_kwargs, payload = run_search(
        SearchRegisteredModels(filter="hey",
                               max_results=500,
                               order_by=["a", "B desc"],
                               page_token="prev"),
        PagedList(all_models, "DONE"))
    assert store_kwargs == {
        "filter_string": "hey",
        "max_results": 500,
        "order_by": ["a", "B desc"],
        "page_token": "prev"
    }
    assert payload == {
        "registered_models": jsonify(all_models),
        "next_page_token": "DONE"
    }