Code example #1
def test_mlflow_model_logger_logging_deactivation_is_bool():
    mlflow_model_logger_dataset = MlflowModelLoggerDataSet(
        flavor="mlflow.sklearn")

    with pytest.raises(ValueError,
                       match="_logging_activated must be a boolean"):
        mlflow_model_logger_dataset._logging_activated = "hello"
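These snippets are all extracted from one test module, so the shared import preamble was dropped. A plausible reconstruction, inferred from the identifiers and config strings used across the examples (the exact module paths, notably for KedroPipelineModel and DataSetError, are assumptions):

import mlflow
import pandas as pd
import pytest
from mlflow.tracking import MlflowClient
from sklearn.linear_model import LinearRegression

from kedro.io.core import DataSetError
from kedro_mlflow.io.models import MlflowModelLoggerDataSet
from kedro_mlflow.mlflow import KedroPipelineModel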
Code example #2
def test_flavor_does_not_exists():
    with pytest.raises(DataSetError, match="mlflow.whoops module not found"):
        MlflowModelLoggerDataSet.from_config(
            name="whoops",
            config={
                "type": "kedro_mlflow.io.models.MlflowModelLoggerDataSet",
                "flavor": "mlflow.whoops",
            },
        )
Code example #3
def test_save_and_load_sklearn_flavor_without_run_id(tracking_uri,
                                                     mlflow_client,
                                                     linreg_model,
                                                     initial_active_run):

    mlflow.set_tracking_uri(tracking_uri)
    # close all opened mlflow runs to avoid interference between tests
    while mlflow.active_run():
        mlflow.end_run()

    artifact_path = "my_linreg"
    model_config = {
        "name": "linreg",
        "config": {
            "type": "kedro_mlflow.io.models.MlflowModelLoggerDataSet",
            "run_id": None,
            "artifact_path": artifact_path,
            "flavor": "mlflow.sklearn",
        },
    }
    mlflow_model_ds = MlflowModelLoggerDataSet.from_config(**model_config)

    # if there is no initial active run, "_save" opens one itself
    if initial_active_run:
        mlflow.start_run()
    mlflow_model_ds.save(linreg_model)
    current_run_id = mlflow.active_run().info.run_id

    artifact = mlflow_client.list_artifacts(run_id=current_run_id)[0]
    assert artifact.path == artifact_path

    # the run is still open, so no run_id is needed to reload
    mlflow_model_ds = MlflowModelLoggerDataSet.from_config(**model_config)
    linreg_model_loaded = mlflow_model_ds.load()
    assert isinstance(linreg_model_loaded, LinearRegression)
    assert pytest.approx(linreg_model_loaded.predict([[1, 2]])[0],
                         abs=10**(-14)) == 5

    # load a second time after closing the active_run
    mlflow.end_run()
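    # NB: dict.copy() is shallow, so the nested "config" dict stays shared with
    # model_config; harmless here because model_config is not reused afterwards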
    model_config2 = model_config.copy()
    model_config2["config"]["run_id"] = current_run_id
    mlflow_model_ds2 = MlflowModelLoggerDataSet.from_config(**model_config2)
    linreg_model_loaded2 = mlflow_model_ds2.load()

    assert isinstance(linreg_model_loaded2, LinearRegression)
    assert (pytest.approx(linreg_model_loaded2.predict([[1, 2]])[0],
                          abs=10**(-14)) == 5)
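Several examples also depend on pytest fixtures defined elsewhere, presumably in a conftest.py: tracking_uri, mlflow_client, linreg_model, and the parametrized flags initial_active_run and active_run_when_loading. (The Kedro-specific fixtures tmp_folder, pipeline, and dummy_catalog used by example #5 are too project-specific to reconstruct.) A minimal hypothetical sketch, assuming a file-based tracking store and a model fit on exactly linear data so that predict([[1, 2]]) is 5, as the assertions expect:

import numpy as np
import pytest
from mlflow.tracking import MlflowClient
from sklearn.linear_model import LinearRegression


@pytest.fixture
def tracking_uri(tmp_path):
    # a file-based tracking store local to each test, so runs cannot leak between tests
    return (tmp_path / "mlruns").as_uri()


@pytest.fixture
def mlflow_client(tracking_uri):
    return MlflowClient(tracking_uri=tracking_uri)


@pytest.fixture
def linreg_model():
    # exactly linear data, y = x1 + 2 * x2, hence predict([[1, 2]]) ~= 5
    X = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, 2.0]])
    y = X[:, 0] + 2 * X[:, 1]
    return LinearRegression().fit(X, y)


@pytest.fixture(params=[False, True])
def initial_active_run(request):
    return request.param


@pytest.fixture(params=[False, True])
def active_run_when_loading(request):
    return request.param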
Code example #4
def test_save_sklearn_flavor_with_run_id_and_already_active_run(
        tracking_uri, linreg_model):
    """Saving an mlflow dataset must fail if a run_id is specified
    but differs from the currently active run (mlflow.active_run())."""
    mlflow.set_tracking_uri(tracking_uri)
    # close all opened mlflow runs to avoid interference between tests
    while mlflow.active_run():
        mlflow.end_run()

    mlflow.start_run()
    existing_run_id = mlflow.active_run().info.run_id
    mlflow.end_run()

    artifact_path = "my_linreg"
    model_config = {
        "name": "linreg",
        "config": {
            "type": "kedro_mlflow.io.models.MlflowModelLoggerDataSet",
            "run_id": existing_run_id,
            "artifact_path": artifact_path,
            "flavor": "mlflow.sklearn",
        },
    }
    mlflow_model_ds = MlflowModelLoggerDataSet.from_config(**model_config)

    # while a run is active, it is impossible to log to a different run
    with mlflow.start_run():
        with pytest.raises(
                DataSetError,
                match=
                "'run_id' cannot be specified if there is an mlflow active run.",
        ):
            mlflow_model_ds.save(linreg_model)
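Note the contrast with example #7 below: when no run is active at save time, the dataset opens, logs to, and closes the specified run_id itself.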
Code example #5
def test_pyfunc_flavor_python_model_save_and_load(
    tmp_folder,
    tracking_uri,
    pipeline,
    dummy_catalog,
):

    kedro_pipeline_model = KedroPipelineModel(
        pipeline=pipeline,
        catalog=dummy_catalog,
        input_name="raw_data",
    )
    artifacts = kedro_pipeline_model.extract_pipeline_artifacts(tmp_folder)

    model_config = {
        "name": "kedro_pipeline_model",
        "config": {
            "type": "kedro_mlflow.io.models.MlflowModelLoggerDataSet",
            "flavor": "mlflow.pyfunc",
            "pyfunc_workflow": "python_model",
            "artifact_path": "test_model",
            "save_args": {
                "artifacts": artifacts,
                "conda_env": {
                    "python": "3.7.0",
                    "dependencies": ["kedro==0.16.5"]
                },
            },
        },
    }

    mlflow.set_tracking_uri(tracking_uri)
    mlflow_model_ds = MlflowModelLoggerDataSet.from_config(**model_config)
    mlflow_model_ds.save(kedro_pipeline_model)
    current_run_id = mlflow.active_run().info.run_id

    # close the run, create another dataset and reload
    # (emulates a new "kedro run" that reloads the previously logged model)
    mlflow.end_run()
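    # NB: dict.copy() is shallow; the nested "config" dict is shared with
    # model_config, which is fine since model_config is not reused below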
    model_config2 = model_config.copy()
    model_config2["config"]["run_id"] = current_run_id
    mlflow_model_ds2 = MlflowModelLoggerDataSet.from_config(**model_config2)

    loaded_model = mlflow_model_ds2.load()

    # the element-wise comparison must actually be asserted; without the
    # assert, the expression below would be a no-op
    assert (
        loaded_model.predict(pd.DataFrame(data=[1], columns=["a"]))
        == pd.DataFrame(data=[2], columns=["a"])
    ).all(axis=None)
Code example #6
def test_pyfunc_flavor_wrong_pyfunc_workflow(tracking_uri):

    model_config = {
        "name": "kedro_pipeline_model",
        "config": {
            "type": "kedro_mlflow.io.models.MlflowModelLoggerDataSet",
            "flavor": "mlflow.pyfunc",
            "pyfunc_workflow": "wrong_workflow",
            "artifact_path": "test_model",
        },
    }
    with pytest.raises(
            DataSetError,
            match=
            r"PyFunc models require specifying `pyfunc_workflow` \(set to either `python_model` or `loader_module`\)",
    ):
        MlflowModelLoggerDataSet.from_config(**model_config)
Code example #7
def test_save_and_load_sklearn_flavor_with_run_id(tracking_uri, mlflow_client,
                                                  linreg_model,
                                                  active_run_when_loading):

    mlflow.set_tracking_uri(tracking_uri)
    # close all opened mlflow runs to avoid interference between tests
    while mlflow.active_run():
        mlflow.end_run()

    mlflow.start_run()
    existing_run_id = mlflow.active_run().info.run_id
    mlflow.end_run()

    artifact_path = "my_linreg"
    model_config = {
        "name": "linreg",
        "config": {
            "type": "kedro_mlflow.io.models.MlflowModelLoggerDataSet",
            "run_id": existing_run_id,
            "artifact_path": artifact_path,
            "flavor": "mlflow.sklearn",
        },
    }
    mlflow_model_ds = MlflowModelLoggerDataSet.from_config(**model_config)

    # "_save" opens, log and close the specified run
    mlflow_model_ds.save(linreg_model)

    artifact = mlflow_client.list_artifacts(run_id=existing_run_id)[0]
    assert artifact.path == artifact_path

    if not active_run_when_loading:
        mlflow.end_run()

    mlflow_model_ds = MlflowModelLoggerDataSet.from_config(**model_config)
    linreg_model_loaded = mlflow_model_ds.load()
    assert isinstance(linreg_model_loaded, LinearRegression)
    assert pytest.approx(linreg_model_loaded.predict([[1, 2]])[0],
                         abs=10**(-14)) == 5
Code example #8
def test_mlflow_model_logger_logging_deactivation(tracking_uri, linreg_model):
    mlflow_model_logger_dataset = MlflowModelLoggerDataSet(
        flavor="mlflow.sklearn")

    mlflow.set_tracking_uri(tracking_uri)
    mlflow_client = MlflowClient(tracking_uri=tracking_uri)

    mlflow_model_logger_dataset._logging_activated = False

    # snapshot all existing run ids across every experiment
    all_runs_id_beginning = {
        run.run_id
        for k in range(len(mlflow_client.list_experiments()))
        for run in mlflow_client.list_run_infos(experiment_id=f"{k}")
    }

    mlflow_model_logger_dataset.save(linreg_model)

    all_runs_id_end = {
        run.run_id
        for k in range(len(mlflow_client.list_experiments()))
        for run in mlflow_client.list_run_infos(experiment_id=f"{k}")
    }

    assert all_runs_id_beginning == all_runs_id_end
Code example #9
def test_load_without_run_id_nor_active_run(tracking_uri):
    mlflow.set_tracking_uri(tracking_uri)
    # close all opened mlflow runs to avoid interference between tests
    while mlflow.active_run():
        mlflow.end_run()

    artifact_path = "my_linreg"
    model_config = {
        "name": "linreg",
        "config": {
            "type": "kedro_mlflow.io.models.MlflowModelLoggerDataSet",
            "run_id": None,
            "artifact_path": artifact_path,
            "flavor": "mlflow.sklearn",
        },
    }
    mlflow_model_ds = MlflowModelLoggerDataSet.from_config(**model_config)

    with pytest.raises(
            DataSetError,
            match="To access the model_uri, you must either",
    ):
        mlflow_model_ds.load()