Пример #1
0
def test_load_multi_model(multi_model_folder: str):
    """Check that a multi-model folder yields five models, v0 through v4."""
    _, models = load_settings(multi_model_folder)

    # Order by version so the check below does not depend on the order in
    # which `load_settings` happens to return the models.
    models.sort(key=lambda loaded: loaded.version)

    assert len(models) == 5

    found_versions = [loaded.version for loaded in models]
    expected_versions = [f"v{idx}" for idx in range(len(models))]
    assert found_versions == expected_versions
Пример #2
0
def test_load_settings_fallback(monkeypatch, sum_model_settings: ModelSettings,
                                model_folder: str):
    """Check settings loading once the settings file has been deleted.

    The HTTP port is provided through an environment variable and the
    settings file is removed from the model folder; the loaded settings are
    then expected to expose that port, and the first loaded model is expected
    to carry the name from ``sum_model_settings``.
    """
    expected_port = 5000
    port_env_var = f"{ENV_PREFIX_SETTINGS}HTTP_PORT"
    monkeypatch.setenv(port_env_var, str(expected_port))

    # Delete the settings file so it cannot be the source of the port value.
    os.remove(os.path.join(model_folder, DEFAULT_SETTINGS_FILENAME))

    settings, models = load_settings(model_folder)

    assert settings.http_port == expected_port
    first_model = models[0]
    assert first_model.name == sum_model_settings.name
Пример #3
0
def test_load_model_settings_fallback(monkeypatch,
                                      sum_model_settings: ModelSettings,
                                      model_folder: str):
    """Check model loading once the model settings file has been deleted.

    The model's name, version and implementation import path are provided
    through environment variables and the model settings file is removed from
    the model folder; exactly one model is then expected, matching the name
    and version of ``sum_model_settings``.
    """
    env_overrides = (
        (f"{ENV_PREFIX_MODEL_SETTINGS}NAME", sum_model_settings.name),
        (f"{ENV_PREFIX_MODEL_SETTINGS}VERSION", sum_model_settings.version),
        (
            f"{ENV_PREFIX_MODEL_SETTINGS}IMPLEMENTATION",
            get_import_path(
                sum_model_settings.implementation),  # type: ignore
        ),
    )
    for env_name, env_value in env_overrides:
        monkeypatch.setenv(env_name, env_value)

    # Delete the model settings file so it cannot be the source of the values.
    os.remove(os.path.join(model_folder, DEFAULT_MODEL_SETTINGS_FILENAME))

    _, models = load_settings(model_folder)

    assert len(models) == 1
    loaded = models[0]
    assert loaded.name == sum_model_settings.name
    assert loaded.version == sum_model_settings.version
Пример #4
0
def test_load_model_settings(model_folder: str):
    """Check that the first loaded model's parameters point at its folder.

    The model's settings must have truthy ``parameters`` whose ``uri`` equals
    the string form of ``model_folder``.
    """
    _, models = load_settings(model_folder)
    first_model = models[0]
    expected_uri = str(model_folder)

    # Reaches into the private `_settings` attribute of the loaded model.
    loaded_settings = first_model._settings
    assert loaded_settings.parameters
    assert loaded_settings.parameters.uri == expected_uri
Пример #5
0
def test_load_models(sum_model_settings: ModelSettings, model_folder: str):
    """Check that the model folder yields exactly the expected single model.

    The one loaded model must match ``sum_model_settings`` in both name and
    version.
    """
    _, models = load_settings(model_folder)

    assert len(models) == 1

    loaded, expected = models[0], sum_model_settings
    assert loaded.name == expected.name
    assert loaded.version == expected.version