def test_add_comparison_data_with_fixture(test_config_comparison: PlotCrossValidationConfig) -> None:
    """
    Check that add_comparison_data combines the focus run's metrics with the comparison run's metrics.

    This variant takes its config from the ``test_config_comparison`` fixture. It is deliberately named
    differently from the self-contained ``test_add_comparison_data`` test in this module. Two functions
    with the same name would make the later definition shadow this one, and pytest would never collect it.

    :param test_config_comparison: Cross-validation plotting config with a focus run
        (``run_recovery_id``) and at least one entry in ``comparison_run_recovery_ids``.
    """
    # NOTE(review): this mutates the fixture object in place. If the fixture is not function-scoped,
    # the epoch change leaks into other tests that use it — confirm the fixture's scope.
    test_config_comparison.epoch = 2
    # The second return value (root folder) is not needed for this check.
    metrics_df, _ = download_metrics(test_config_comparison)
    # metrics_df is mapping-like (it has .values()); its values are concatenated into one frame.
    initial_metrics = pd.concat(list(metrics_df.values()))
    all_metrics, focus_splits = add_comparison_data(test_config_comparison, initial_metrics)
    focus_split = test_config_comparison.run_recovery_id
    comparison_split = test_config_comparison.comparison_run_recovery_ids[0]
    # Only the focus run should be reported back as a focus split.
    assert focus_splits == [focus_split]
    # The combined metrics must contain rows from both the focus run and the comparison run.
    assert set(all_metrics.split) == {focus_split, comparison_split}


def test_add_comparison_data() -> None:
    """
    Check add_comparison_data end to end, using a config built inside the test.

    Split ``_0`` of the run returned by ``get_most_recent_run_id`` is the focus run, and split ``_1``
    of the same run is the comparison run. ``FALLBACK_ENSEMBLE_RUN`` is passed as the fallback for
    local execution. The check is that the combined metrics contain exactly those two splits and that
    only the focus split is reported as a focus split.
    """
    fallback_run = get_most_recent_run_id(
        fallback_run_id_for_local_execution=FALLBACK_ENSEMBLE_RUN)
    # Build the config with the epoch under test directly. The original constructed it with
    # epoch=1 and then immediately overwrote that value with 2.
    crossval_config = PlotCrossValidationConfig(
        run_recovery_id=fallback_run + "_0",
        epoch=2,
        comparison_run_recovery_ids=[fallback_run + "_1"],
        model_category=ModelCategory.Segmentation)
    # The second return value (root folder) is not needed for this check.
    metrics_df, _ = download_metrics(crossval_config)
    # metrics_df is mapping-like (it has .values()); its values are concatenated into one frame.
    initial_metrics = pd.concat(list(metrics_df.values()))
    all_metrics, focus_splits = add_comparison_data(crossval_config, initial_metrics)
    # Read the expected split names back from the config rather than rebuilding the strings.
    focus_split = crossval_config.run_recovery_id
    comparison_split = crossval_config.comparison_run_recovery_ids[0]
    # Only the focus run should be reported back as a focus split.
    assert focus_splits == [focus_split]
    # The combined metrics must contain rows from both the focus run and the comparison run.
    assert set(all_metrics.split) == {focus_split, comparison_split}