def test_extract_outliers() -> None:
    """
    Test that extract_outliers returns exactly the rows where Dice < mean - outlier_range * std.
    """
    test_df = pd.DataFrame({"Dice": range(10)})

    # Expected outliers when the threshold is 0, 1 and 2 standard deviations below the mean
    expected_by_range = {0: list(range(5)), 1: list(range(2)), 2: []}
    for outlier_range, expected in expected_by_range.items():
        actual = extract_outliers(test_df, outlier_range)
        assert expected == list(actual.Dice.values)
def display_outliers(df: pd.DataFrame, outlier_range: float, metric_name: str,
                     high_values_are_good: bool) -> pd.DataFrame:
    """
    Returns the rows of the dataframe that are outliers ("worst patients") for the given metric.

    When high metric values are good (e.g. Dice), outliers are rows whose value is below
    mean - outlier_range * std; otherwise (e.g. Hausdorff distance) they are rows whose
    value is above mean + outlier_range * std.

    :param df: DataFrame containing the metric values.
    :param outlier_range: Number of standard deviations from the mean that defines an outlier.
    :param metric_name: Name of the metric column to analyse.
    :param high_values_are_good: True if higher values of the metric indicate better results.
    :return: The subset of df that qualifies as outliers.
    """
    return extract_outliers(
        df, outlier_range, metric_name,
        OutlierType.LOW if high_values_are_good else OutlierType.HIGH)
def test_extract_outliers_higher() -> None:
    """
    Test that extract_outliers with OutlierType.HIGH returns exactly the rows where
    Hausdorff > mean + outlier_range * std.
    """
    test_df = pd.DataFrame({"Hausdorff": range(10)})

    # (outlier_range, expected Hausdorff values above mean + outlier_range * std)
    cases = [
        (0, list(range(5, 10))),
        (1, list(range(8, 10))),
        (2, []),
    ]
    for outlier_range, expected in cases:
        actual = extract_outliers(test_df,
                                  outlier_range,
                                  "Hausdorff",
                                  outlier_type=OutlierType.HIGH)
        assert expected == list(actual.Hausdorff.values)
def save_outliers(config: PlotCrossValidationConfig,
                  dataset_split_metrics: Dict[ModelExecutionMode,
                                              pd.DataFrame],
                  root: Path) -> Dict[ModelExecutionMode, Path]:
    """
    Given the dataframe for the downloaded metrics, identifies outliers for each available
    metric (values more than config.outlier_range standard deviations worse than the mean)
    across the splits, and saves a summary per mode to a file <mode>_outliers.txt in the
    provided root.

    :param config: PlotCrossValidationConfig; config.outlier_range is the number of standard
        deviations from the mean that defines an outlier.
    :param dataset_split_metrics: Mapping between model execution mode and a dataframe containing all metrics for it
    :param root: Root directory to the results for Train/Test and Val datasets
    :return: Dictionary of mode and file path.
    """
    stats_columns = ['count', 'mean', 'min', 'max']
    outliers_paths = {}
    for mode, df in dataset_split_metrics.items():
        outliers_std = root / "{}_outliers.txt".format(mode.value)
        with open(outliers_std, 'w') as f:
            # to make sure no columns or rows are truncated
            with DEFAULT_PD_DISPLAY_CONTEXT:
                for metric_type, metric_type_metadata in get_available_metrics(
                        df).items():
                    # Pass the metric column explicitly (third positional argument), so each
                    # section reports outliers for this metric rather than the default column.
                    outliers = extract_outliers(
                        df,
                        config.outlier_range,
                        metric_type,
                        outlier_type=metric_type_metadata["outlier_type"]
                    ).drop([COL_SPLIT], axis=1)

                    f.write(f"\n\n=== METRIC: {metric_type} ===\n\n")
                    if len(outliers) > 0:
                        # If running inside institution there may be no CSV_SERIES_HEADER or CSV_INSTITUTION_HEADER
                        # columns
                        groupby_columns = [
                            MetricsFileColumns.Patient.value,
                            MetricsFileColumns.Structure.value
                        ]
                        if CSV_SERIES_HEADER in outliers.columns:
                            groupby_columns.append(CSV_SERIES_HEADER)
                        if CSV_INSTITUTION_HEADER in outliers.columns:
                            groupby_columns.append(CSV_INSTITUTION_HEADER)
                        # Summarize the outliers per patient/structure (plus optional series and
                        # institution), sorted worst-first by the stats columns.
                        outliers_summary = str(
                            outliers.groupby(groupby_columns).describe()
                            [metric_type][stats_columns].sort_values(
                                stats_columns, ascending=False))
                        f.write(outliers_summary)
                    else:
                        f.write("No outliers found")

        print("Saved outliers to: {}".format(outliers_std))
        outliers_paths[mode] = outliers_std

    return outliers_paths