def test_mlflow_metrics_dataset_saved_and_logged(tmp_path, tracking_uri, data, prefix):
    """Check if MlflowMetricsDataSet can be saved in catalog when filepath is
    given, and if logged in mlflow.
    """
    mlflow.set_tracking_uri(tracking_uri.as_uri())
    mlflow_client = MlflowClient(tracking_uri=tracking_uri.as_uri())
    mlflow_metrics_dataset = MlflowMetricsDataSet(prefix=prefix)

    with mlflow.start_run():
        run_id = mlflow.active_run().info.run_id
        mlflow_metrics_dataset.save(data)

        # Check if metrics were logged correctly in MLflow.
        assert_are_metrics_logged(data, mlflow_client, run_id, prefix)

    # Check if metrics are stored in catalog.
    catalog_metrics = MlflowMetricsDataSet(
        prefix=prefix,
        # Run id needs to be provided as there is no active run.
        run_id=run_id,
    ).load()

    assert len(catalog_metrics) == len(data)
    for logged_name, logged_value in catalog_metrics.items():
        # The prefix (if any) was prepended as "<prefix>.<key>"; strip it back off.
        original_key = logged_name.split(".")[-1] if prefix is not None else logged_name
        assert data[original_key] == logged_value
def test_mlflow_metrics_dataset_exists(tmp_path, tracking_uri, metrics3):
    """Check if MlflowMetricsDataSet is well identified as existing if it
    has already been saved.
    """
    metric_prefix = "test_metric"
    mlflow.set_tracking_uri(tracking_uri.as_uri())
    dataset = MlflowMetricsDataSet(prefix=metric_prefix)

    # An mlflow run_id is automatically created on save.
    dataset.save(metrics3)

    assert dataset.exists()
def test_mlflow_metrics_dataset_fails_with_invalid_metric(
    tmp_path, tracking_uri, metrics3
):
    """Check if MlflowMetricsDataSet is well identified as not existing if it
    has never been saved.
    """
    mlflow.set_tracking_uri(tracking_uri.as_uri())
    dataset = MlflowMetricsDataSet(prefix="test_metric")

    with pytest.raises(
        DataSetError, match="Unexpected metric value. Should be of type"
    ):
        # key: value is not valid, you must specify {key: {value, step}}
        dataset.save({"metric1": 1})
def test_mlflow_metrics_dataset_saved_without_run_id(tmp_path, tracking_uri, metrics3):
    """Check if MlflowMetricsDataSet can be saved in catalog when filepath is
    given, and if logged in mlflow.
    """
    metric_prefix = "test_metric"
    mlflow.set_tracking_uri(tracking_uri.as_uri())
    client = MlflowClient(tracking_uri=tracking_uri.as_uri())
    dataset = MlflowMetricsDataSet(prefix=metric_prefix)

    # No explicit run_id: saving creates an mlflow run automatically.
    dataset.save(metrics3)
    run_id = mlflow.active_run().info.run_id

    assert_are_metrics_logged(metrics3, client, run_id, metric_prefix)
def test_mlflow_metrics_logging_deactivation(tracking_uri, metrics):
    """Check that saving with logging deactivated does not create any new
    mlflow run.
    """
    dataset = MlflowMetricsDataSet(prefix="hello")
    mlflow.set_tracking_uri(tracking_uri.as_uri())
    mlflow_client = MlflowClient(tracking_uri=tracking_uri.as_uri())

    dataset._logging_activated = False

    def snapshot_run_ids():
        # Collect the run ids of every run across all tracked experiments.
        return {
            run.run_id
            for k in range(len(mlflow_client.list_experiments()))
            for run in mlflow_client.list_run_infos(experiment_id=f"{k}")
        }

    runs_before = snapshot_run_ids()
    dataset.save(metrics)
    runs_after = snapshot_run_ids()

    # Deactivated logging must leave the set of runs untouched.
    assert runs_before == runs_after