def test_add_comparison_data_with_fixture(test_config_comparison: PlotCrossValidationConfig) -> None:
    """
    Check that add_comparison_data merges the focus run's metrics with those of the
    comparison run, using the pre-built ``test_config_comparison`` fixture.

    Renamed from ``test_add_comparison_data``: a second function of that name is
    defined later in this module, which shadowed this one so it was never collected
    or executed by pytest.
    """
    # Override the epoch the metrics are downloaded for.
    test_config_comparison.epoch = 2
    # Downloads the per-split metrics files for the configured runs (network/run I/O).
    metrics_df, root_folder = download_metrics(test_config_comparison)
    initial_metrics = pd.concat(list(metrics_df.values()))
    all_metrics, focus_splits = add_comparison_data(test_config_comparison, initial_metrics)
    focus_split = test_config_comparison.run_recovery_id
    comparison_split = test_config_comparison.comparison_run_recovery_ids[0]
    # The focus split must be exactly the configured run, and the merged frame must
    # contain rows for both the focus run and the (single) comparison run.
    assert focus_splits == [focus_split]
    assert set(all_metrics.split) == {focus_split, comparison_split}
def test_add_comparison_data() -> None:
    """
    Check that add_comparison_data merges the focus run's metrics with those of the
    comparison run, building the cross-validation config from the most recent
    fallback ensemble run (child 0 as focus, child 1 as comparison).
    """
    fallback_run = get_most_recent_run_id(
        fallback_run_id_for_local_execution=FALLBACK_ENSEMBLE_RUN)
    # NOTE(review): the original passed epoch=1 here and immediately overwrote it
    # with 2; the dead initial value has been removed by constructing with 2 directly.
    crossval_config = PlotCrossValidationConfig(
        run_recovery_id=fallback_run + "_0",
        epoch=2,
        comparison_run_recovery_ids=[fallback_run + "_1"],
        model_category=ModelCategory.Segmentation)
    # Downloads the per-split metrics files for the configured runs (network/run I/O).
    metrics_df, root_folder = download_metrics(crossval_config)
    initial_metrics = pd.concat(list(metrics_df.values()))
    all_metrics, focus_splits = add_comparison_data(crossval_config, initial_metrics)
    focus_split = crossval_config.run_recovery_id
    comparison_split = crossval_config.comparison_run_recovery_ids[0]
    # The focus split must be exactly the configured run, and the merged frame must
    # contain rows for both the focus run and the (single) comparison run.
    assert focus_splits == [focus_split]
    assert set(all_metrics.split) == {focus_split, comparison_split}