Пример #1
0
def test_list_registered_models(mock_get_request_message, mock_model_registry_store):
    """The list handler returns every detailed registered model the store reports.

    The request message is stubbed with an empty ``ListRegisteredModels`` proto,
    the store is stubbed to return two detailed models, and the JSON response is
    compared against the serialized form of those same models.
    """
    list_fn = mock_model_registry_store.list_registered_models
    mock_get_request_message.return_value = ListRegisteredModels()

    # (name, last_updated_timestamp, description) for each stubbed model.
    fixture_rows = (
        ("model_1", 222, "Test model"),
        ("model_2", 333, "Another model"),
    )
    expected_models = [
        RegisteredModelDetailed(
            name=model_name,
            creation_timestamp=111,
            last_updated_timestamp=updated_ts,
            description=desc,
            latest_versions=[],
        )
        for model_name, updated_ts, desc in fixture_rows
    ]
    list_fn.return_value = expected_models

    response = _list_registered_models()

    # The handler must call the store with no positional arguments.
    positional_args, _kwargs = list_fn.call_args
    assert positional_args == ()

    payload = json.loads(response.get_data())
    assert payload == {"registered_models_detailed": jsonify(expected_models)}
Пример #2
0
def test_get_registered_model_details(mock_store):
    """The client returns the store's detailed model, including its latest versions.

    The mocked store yields a detailed model with three versions (one per stage).
    The test checks that the store was hit exactly once and that the name and
    version count pass through unchanged.
    """
    model_name = "Model 1"
    stage_by_version = ((3, "None"), (2, "Staging"), (1, "Production"))
    versions = [
        _model_version_detailed(model_name, version, stage)
        for version, stage in stage_by_version
    ]

    details_fn = mock_store.get_registered_model_details
    details_fn.return_value = RegisteredModelDetailed(
        model_name, "1263283747835", "1283168374623874", "I am a model", versions
    )

    client = newModelRegistryClient()
    fetched = client.get_registered_model_details(model_name)

    details_fn.assert_called_once()
    assert fetched.name == model_name
    assert len(fetched.latest_versions) == 3
Пример #3
0
def test_get_registered_model_details(mock_get_request_message, mock_model_registry_store):
    """The handler forwards the requested model to the store and echoes its details.

    The request message carries ``RegisteredModel("model1")``. The test checks that
    this entity is passed positionally to the store, and that the JSON response
    equals the serialized detailed model the mocked store returns.
    """
    requested_model = RegisteredModel("model1")
    request_proto = GetRegisteredModelDetails(registered_model=requested_model.to_proto())
    mock_get_request_message.return_value = request_proto

    # NOTE(review): the request names "model1" while the stubbed response is
    # "model_1". The mocked store ignores the name, so the test still passes;
    # consider aligning the two names for readability.
    detailed = RegisteredModelDetailed(
        name="model_1",
        creation_timestamp=111,
        last_updated_timestamp=222,
        description="Test model",
        latest_versions=[],
    )
    details_fn = mock_model_registry_store.get_registered_model_details
    details_fn.return_value = detailed

    response = _get_registered_model_details()

    positional_args, _kwargs = details_fn.call_args
    assert positional_args == (requested_model, )

    payload = json.loads(response.get_data())
    assert payload == {"registered_model_detailed": jsonify(detailed)}
Пример #4
0
 def to_mlflow_detailed_entity(self):
     """Build a ``RegisteredModelDetailed`` entity from this registered-model row.

     Only the highest-numbered model version in each stage is kept. Among
     versions with the same number in one stage, the first one seen wins. The
     resulting versions are ordered by each stage's first appearance in
     ``self.model_versions``.
     """
     # self.model_versions holds every version of this model (see the original
     # "backref" note); reduce it to the newest version per stage.
     newest_by_stage = {}
     for version_row in self.model_versions:
         stage_key = version_row.current_stage
         incumbent = newest_by_stage.get(stage_key)
         # Strict '<' means an equal version number never displaces the incumbent.
         if incumbent is None or incumbent.version < version_row.version:
             newest_by_stage[stage_key] = version_row

     detailed_versions = [
         row.to_mlflow_detailed_entity() for row in newest_by_stage.values()
     ]
     return RegisteredModelDetailed(
         self.name,
         self.creation_time,
         self.last_updated_time,
         self.description,
         detailed_versions,
     )