Beispiel #1
0
def test_provide_saved_model_caching_handle_existing_same_dir(tmpdir):
    """
    Building repeatedly into the output_dir that the model register already
    points at should keep ``cached_model_path`` equal to that output_dir.
    """
    register_path = os.path.join(tmpdir, "registry")
    output_dir = os.path.join(tmpdir, "model")

    machine = Machine(
        name="model-name",
        dataset=get_random_data(),
        model={"sklearn.decomposition.PCA": {"svd_solver": "auto"}},
        project_name="test",
    )
    builder = ModelBuilder(machine)

    # Pass one is the initial build with a register directory supplied; pass
    # two requests the very same output_dir again, so the builder is expected
    # to simply hand that output_dir back as the cached path.
    for _ in range(2):
        builder.build(output_dir=output_dir, model_register_dir=register_path)
        assert builder.cached_model_path == output_dir
Beispiel #2
0
def test_output_dir(tmpdir):
    """
    Saving a built model into a nested, not-yet-existing directory should
    leave exactly two entries in that directory.
    """
    nested_dir = os.path.join(tmpdir, "some", "sub", "directories")

    builder = ModelBuilder(
        Machine(
            name="model-name",
            dataset=get_random_data(),
            model={"sklearn.decomposition.PCA": {"svd_solver": "auto"}},
            project_name="test",
        )
    )

    # Build without an output_dir, then validate the returned machine.
    built_model, built_machine = builder.build()
    machine_check(built_machine, False)

    # Persist explicitly into the nested target directory.
    builder._save_model(model=built_model, machine=built_machine, output_dir=nested_dir)

    # Two entries are expected at the target location:
    # the model file and the metadata.
    saved_entries = os.listdir(nested_dir)
    assert len(saved_entries) == 2
Beispiel #3
0
def test_provide_saved_model_caching_handle_existing_different_register(tmpdir):
    """
    When the model is already in the model register but a different output_dir
    is requested, the cached path should be the newly requested output_dir;
    requesting that same (now existing) output_dir again should return it too.
    """
    first_dir, second_dir = (
        os.path.join(tmpdir, dirname) for dirname in ("model1", "model2")
    )
    register_path = os.path.join(tmpdir, "registry")

    machine = Machine(
        name="model-name",
        dataset=get_random_data(),
        model={"sklearn.decomposition.PCA": {"svd_solver": "auto"}},
        project_name="test",
    )
    builder = ModelBuilder(machine)

    # Initial build targets the first directory, using the shared register.
    builder.build(output_dir=first_dir, model_register_dir=register_path)

    # Ask for the second directory twice: once while it does not yet exist,
    # and once more after the previous pass. Both must report the second
    # directory as the cached model path.
    for _ in range(2):
        builder.build(output_dir=second_dir, model_register_dir=register_path)
        assert builder.cached_model_path == second_dir