def test_mlflow_metrics_dataset_saved_and_logged(tmp_path, tracking_uri, data,
                                                 prefix):
    """Check if MlflowMetricsDataSet can be saved in catalog when filepath is given,
    and if logged in mlflow.
    """
    mlflow.set_tracking_uri(tracking_uri.as_uri())
    mlflow_client = MlflowClient(tracking_uri=tracking_uri.as_uri())
    mlflow_metrics_dataset = MlflowMetricsDataSet(prefix=prefix)

    with mlflow.start_run():
        run_id = mlflow.active_run().info.run_id
        mlflow_metrics_dataset.save(data)

        # Check if metrics were logged correctly in MLflow.
        assert_are_metrics_logged(data, mlflow_client, run_id, prefix)

    # Check if metrics are stored in catalog.
    catalog_metrics = MlflowMetricsDataSet(
        prefix=prefix,
        # Run id needs to be provided as there is no active run.
        run_id=run_id,
    ).load()

    assert len(catalog_metrics) == len(data)
    for k in catalog_metrics.keys():
        data_key = k.split(".")[-1] if prefix is not None else k
        assert data[data_key] == catalog_metrics[k]


def test_mlflow_metrics_dataset_exists(tmp_path, tracking_uri, metrics3):
    """Check that MlflowMetricsDataSet is correctly identified as
    existing once it has been saved.
    """
    prefix = "test_metric"

    mlflow.set_tracking_uri(tracking_uri.as_uri())
    mlflow_metrics_dataset = MlflowMetricsDataSet(prefix=prefix)

    # an mlflow run is automatically created on save
    mlflow_metrics_dataset.save(metrics3)
    assert mlflow_metrics_dataset.exists()


def test_mlflow_metrics_dataset_fails_with_invalid_metric(
        tmp_path, tracking_uri, metrics3):
    """Check that MlflowMetricsDataSet raises a DataSetError when
    an invalid metric payload is saved.
    """

    mlflow.set_tracking_uri(tracking_uri.as_uri())
    mlflow_metrics_dataset = MlflowMetricsDataSet(prefix="test_metric")

    with pytest.raises(DataSetError,
                       match="Unexpected metric value. Should be of type"):
        mlflow_metrics_dataset.save({
            "metric1": 1
        })  # a bare key: value pair is not valid; you must specify {key: {value, step}}
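    # For reference (shape assumed from the comment above), a valid payload would be
    # e.g. mlflow_metrics_dataset.save({"metric1": {"value": 1.1, "step": 0}})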


def test_mlflow_metrics_dataset_saved_without_run_id(tmp_path, tracking_uri,
                                                     metrics3):
    """Check that MlflowMetricsDataSet can be saved without an explicit run_id
    and that metrics are logged to the automatically created mlflow run.
    """
    prefix = "test_metric"

    mlflow.set_tracking_uri(tracking_uri.as_uri())
    mlflow_client = MlflowClient(tracking_uri=tracking_uri.as_uri())
    mlflow_metrics_dataset = MlflowMetricsDataSet(prefix=prefix)

    # an mlflow run is automatically created on save
    mlflow_metrics_dataset.save(metrics3)
    run_id = mlflow.active_run().info.run_id

    assert_are_metrics_logged(metrics3, mlflow_client, run_id, prefix)


def test_mlflow_metrics_logging_deactivation(tracking_uri, metrics):
    """Check that no new mlflow run is created and no metric is logged
    when logging is deactivated on the dataset.
    """
    mlflow_metrics_dataset = MlflowMetricsDataSet(prefix="hello")

    mlflow.set_tracking_uri(tracking_uri.as_uri())
    mlflow_client = MlflowClient(tracking_uri=tracking_uri.as_uri())

    # the private _logging_activated flag disables mlflow logging for this dataset
    mlflow_metrics_dataset._logging_activated = False

    all_runs_id_beginning = {
        run.run_id for k in range(len(mlflow_client.list_experiments()))
        for run in mlflow_client.list_run_infos(experiment_id=f"{k}")
    }

    mlflow_metrics_dataset.save(metrics)

    all_runs_id_end = {
        run.run_id for k in range(len(mlflow_client.list_experiments()))
        for run in mlflow_client.list_run_infos(experiment_id=f"{k}")
    }

    assert all_runs_id_beginning == all_runs_id_end