def test_mlflow_model_logger_logging_deactivation_is_bool():
    """Setting ``_logging_activated`` to a non-boolean value must raise ValueError."""
    dataset = MlflowModelLoggerDataSet(flavor="mlflow.sklearn")
    with pytest.raises(ValueError, match="_logging_activated must be a boolean"):
        dataset._logging_activated = "hello"
def test_flavor_does_not_exists():
    """A flavor pointing to an unimportable mlflow module must fail at config time."""
    bad_config = {
        "type": "kedro_mlflow.io.models.MlflowModelLoggerDataSet",
        "flavor": "mlflow.whoops",
    }
    with pytest.raises(DataSetError, match="mlflow.whoops module not found"):
        MlflowModelLoggerDataSet.from_config(name="whoops", config=bad_config)
def test_save_and_load_sklearn_flavor_without_run_id(
    tracking_uri, mlflow_client, linreg_model, initial_active_run
):
    """Save/load a sklearn model when no ``run_id`` is configured.

    ``save`` must log into the currently active run (opening one itself when
    none is active), and the model must be loadable both while that run is
    still open and afterwards by pointing a new dataset at its run_id.

    Fix over previous version: removed a duplicated, result-discarding
    ``mlflow_client.list_artifacts(...)`` call (dead statement, extra
    tracking-server round trip).
    """
    mlflow.set_tracking_uri(tracking_uri)
    # close all opened mlflow runs to avoid interference between tests
    while mlflow.active_run():
        mlflow.end_run()

    artifact_path = "my_linreg"
    model_config = {
        "name": "linreg",
        "config": {
            "type": "kedro_mlflow.io.models.MlflowModelLoggerDataSet",
            "run_id": None,
            "artifact_path": artifact_path,
            "flavor": "mlflow.sklearn",
        },
    }
    mlflow_model_ds = MlflowModelLoggerDataSet.from_config(**model_config)

    # if no initial active run, "_save" triggers the run opening
    if initial_active_run:
        mlflow.start_run()
    mlflow_model_ds.save(linreg_model)
    current_run_id = mlflow.active_run().info.run_id

    artifact = mlflow_client.list_artifacts(run_id=current_run_id)[0]
    assert artifact.path == artifact_path

    # the run is still opened: a fresh dataset instance must load from it
    mlflow_model_ds = MlflowModelLoggerDataSet.from_config(**model_config)
    linreg_model_loaded = mlflow_model_ds.load()
    assert isinstance(linreg_model_loaded, LinearRegression)
    assert pytest.approx(linreg_model_loaded.predict([[1, 2]])[0], abs=10 ** (-14)) == 5

    # load a second time after closing the active_run
    mlflow.end_run()
    # NOTE(review): shallow copy — mutating the nested "config" dict also mutates
    # model_config; harmless here because model_config is not reused afterwards.
    model_config2 = model_config.copy()
    model_config2["config"]["run_id"] = current_run_id
    mlflow_model_ds2 = MlflowModelLoggerDataSet.from_config(**model_config2)
    linreg_model_loaded2 = mlflow_model_ds2.load()
    assert isinstance(linreg_model_loaded2, LinearRegression)
    assert (
        pytest.approx(linreg_model_loaded2.predict([[1, 2]])[0], abs=10 ** (-14)) == 5
    )
def test_save_sklearn_flavor_with_run_id_and_already_active_run(
    tracking_uri, linreg_model
):
    """Saving must fail if the configured ``run_id`` differs from the active run.

    Fix over previous version: ``linreg_model`` was referenced in the body but
    missing from the signature, so pytest never injected the fixture and the
    test raised NameError instead of exercising the DataSetError path.
    """
    mlflow.set_tracking_uri(tracking_uri)
    # close all opened mlflow runs to avoid interference between tests
    while mlflow.active_run():
        mlflow.end_run()

    # create (then close) a run whose id will be configured on the dataset
    mlflow.start_run()
    existing_run_id = mlflow.active_run().info.run_id
    mlflow.end_run()

    artifact_path = "my_linreg"
    model_config = {
        "name": "linreg",
        "config": {
            "type": "kedro_mlflow.io.models.MlflowModelLoggerDataSet",
            "run_id": existing_run_id,
            "artifact_path": artifact_path,
            "flavor": "mlflow.sklearn",
        },
    }
    mlflow_model_ds = MlflowModelLoggerDataSet.from_config(**model_config)

    # if a run is active, it is impossible to log in another run
    with mlflow.start_run():
        with pytest.raises(
            DataSetError,
            match="'run_id' cannot be specified if there is an mlflow active run.",
        ):
            mlflow_model_ds.save(linreg_model)
def test_pyfunc_flavor_python_model_save_and_load(
    tmp_folder,
    tracking_uri,
    pipeline,
    dummy_catalog,
):
    """Save a KedroPipelineModel as a pyfunc ``python_model`` and reload it.

    Fix over previous version: the final line built a DataFrame ``==``
    comparison and discarded the result — the prediction was never actually
    asserted. Replaced with a real assertion via ``DataFrame.equals``.
    """
    kedro_pipeline_model = KedroPipelineModel(
        pipeline=pipeline,
        catalog=dummy_catalog,
        input_name="raw_data",
    )
    artifacts = kedro_pipeline_model.extract_pipeline_artifacts(tmp_folder)

    model_config = {
        "name": "kedro_pipeline_model",
        "config": {
            "type": "kedro_mlflow.io.models.MlflowModelLoggerDataSet",
            "flavor": "mlflow.pyfunc",
            "pyfunc_workflow": "python_model",
            "artifact_path": "test_model",
            "save_args": {
                "artifacts": artifacts,
                "conda_env": {"python": "3.7.0", "dependencies": ["kedro==0.16.5"]},
            },
        },
    }

    mlflow.set_tracking_uri(tracking_uri)
    mlflow_model_ds = MlflowModelLoggerDataSet.from_config(**model_config)
    mlflow_model_ds.save(kedro_pipeline_model)
    current_run_id = mlflow.active_run().info.run_id

    # close the run, create another dataset and reload
    # (emulate a new "kedro run" with the launch of the )
    mlflow.end_run()
    # NOTE(review): shallow copy — mutating the nested "config" dict also mutates
    # model_config; harmless here because model_config is not reused afterwards.
    model_config2 = model_config.copy()
    model_config2["config"]["run_id"] = current_run_id
    mlflow_model_ds2 = MlflowModelLoggerDataSet.from_config(**model_config2)

    loaded_model = mlflow_model_ds2.load()
    # assumes predict returns a DataFrame (the previous code compared it to one)
    prediction = loaded_model.predict(pd.DataFrame(data=[1], columns=["a"]))
    assert prediction.equals(pd.DataFrame(data=[2], columns=["a"]))
def test_pyfunc_flavor_wrong_pyfunc_workflow(tracking_uri):
    """An invalid ``pyfunc_workflow`` value must fail at dataset construction."""
    bad_config = {
        "type": "kedro_mlflow.io.models.MlflowModelLoggerDataSet",
        "flavor": "mlflow.pyfunc",
        "pyfunc_workflow": "wrong_workflow",
        "artifact_path": "test_model",
    }
    expected_msg = (
        r"PyFunc models require specifying `pyfunc_workflow` "
        r"\(set to either `python_model` or `loader_module`\)"
    )
    with pytest.raises(DataSetError, match=expected_msg):
        MlflowModelLoggerDataSet.from_config(name="kedro_pipeline_model", config=bad_config)
def test_save_and_load_sklearn_flavor_with_run_id(
    tracking_uri, mlflow_client, linreg_model, active_run_when_loading
):
    """Save/load a sklearn model into an explicitly configured ``run_id``.

    ``save`` must open, log into and close the specified run; loading must
    work whether or not a run is active at load time.

    Fix over previous version: removed a duplicated, result-discarding
    ``mlflow_client.list_artifacts(...)`` call (dead statement, extra
    tracking-server round trip).
    """
    mlflow.set_tracking_uri(tracking_uri)
    # close all opened mlflow runs to avoid interference between tests
    while mlflow.active_run():
        mlflow.end_run()

    # create (then close) the run the dataset will be configured to log into
    mlflow.start_run()
    existing_run_id = mlflow.active_run().info.run_id
    mlflow.end_run()

    artifact_path = "my_linreg"
    model_config = {
        "name": "linreg",
        "config": {
            "type": "kedro_mlflow.io.models.MlflowModelLoggerDataSet",
            "run_id": existing_run_id,
            "artifact_path": artifact_path,
            "flavor": "mlflow.sklearn",
        },
    }
    mlflow_model_ds = MlflowModelLoggerDataSet.from_config(**model_config)

    # "_save" opens, log and close the specified run
    mlflow_model_ds.save(linreg_model)
    artifact = mlflow_client.list_artifacts(run_id=existing_run_id)[0]
    assert artifact.path == artifact_path

    if not active_run_when_loading:
        mlflow.end_run()

    mlflow_model_ds = MlflowModelLoggerDataSet.from_config(**model_config)
    linreg_model_loaded = mlflow_model_ds.load()
    assert isinstance(linreg_model_loaded, LinearRegression)
    assert pytest.approx(linreg_model_loaded.predict([[1, 2]])[0], abs=10 ** (-14)) == 5
def test_mlflow_model_logger_logging_deactivation(tracking_uri, linreg_model):
    """With ``_logging_activated`` set to False, ``save`` must create no new run."""
    dataset = MlflowModelLoggerDataSet(flavor="mlflow.sklearn")
    mlflow.set_tracking_uri(tracking_uri)
    client = MlflowClient(tracking_uri=tracking_uri)
    dataset._logging_activated = False

    def _all_run_ids():
        # collect every run id across all experiments (ids enumerated as "0".."N-1")
        return {
            run_info.run_id
            for idx in range(len(client.list_experiments()))
            for run_info in client.list_run_infos(experiment_id=f"{idx}")
        }

    run_ids_before = _all_run_ids()
    dataset.save(linreg_model)
    assert run_ids_before == _all_run_ids()
def test_load_without_run_id_nor_active_run(tracking_uri):
    """Loading with neither a ``run_id`` nor an active run must raise DataSetError.

    Fix over previous version: ``tracking_uri`` was used in the body but
    missing from the signature, so pytest never injected the fixture and the
    test raised NameError instead of exercising the DataSetError path.
    """
    mlflow.set_tracking_uri(tracking_uri)
    # close all opened mlflow runs to avoid interference between tests
    while mlflow.active_run():
        mlflow.end_run()

    artifact_path = "my_linreg"
    model_config = {
        "name": "linreg",
        "config": {
            "type": "kedro_mlflow.io.models.MlflowModelLoggerDataSet",
            "run_id": None,
            "artifact_path": artifact_path,
            "flavor": "mlflow.sklearn",
        },
    }
    mlflow_model_ds = MlflowModelLoggerDataSet.from_config(**model_config)

    with pytest.raises(
        DataSetError,
        match="To access the model_uri, you must either",
    ):
        mlflow_model_ds.load()