Exemple #1
def test_set_ground_truth_with_no_metadata():
    """Passing None as metadata to set_ground_truth() raises PSDSEvalError"""
    gt_path = os.path.join(DATADIR, "test_1.gt")
    ground_truth = pd.read_csv(gt_path, sep="\t")
    evaluator = PSDSEval()
    # A ground-truth table is supplied, so the metadata is the only
    # argument that is missing when the call is made.
    with pytest.raises(PSDSEvalError, match="Audio metadata is required"):
        evaluator.set_ground_truth(ground_truth, None)
Exemple #2
def test_valid_thresholds(x):
    """Check PSDSEval can be built with x as each threshold in turn"""
    threshold_names = ("dtc_threshold", "cttc_threshold", "gtc_threshold")
    # One constructor call per threshold keyword; the other thresholds keep
    # their defaults each time.
    for threshold_name in threshold_names:
        assert PSDSEval(**{threshold_name: x})
Exemple #3
def test_eval_class_with_no_ground_truth():
    """Building PSDSEval with metadata but no GT raises PSDSEvalError"""
    metadata_path = os.path.join(DATADIR, "test.metadata")
    audio_metadata = pd.read_csv(metadata_path, sep="\t")
    expected_message = "The ground truth cannot be set without data"
    with pytest.raises(PSDSEvalError, match=expected_message):
        PSDSEval(metadata=audio_metadata, ground_truth=None)
Exemple #4
def test_eval_class_with_no_metadata():
    """Building PSDSEval with GT but no metadata raises PSDSEvalError"""
    gt_path = os.path.join(DATADIR, "test_1.gt")
    ground_truth = pd.read_csv(gt_path, sep="\t")
    expected_message = "Audio metadata is required"
    with pytest.raises(PSDSEvalError, match=expected_message):
        PSDSEval(metadata=None, ground_truth=ground_truth)
Exemple #5
def tests_num_operating_points_without_any_operating_points():
    """A PSDSEval built with no arguments reports zero operating points"""
    operating_point_count = PSDSEval().num_operating_points()
    assert operating_point_count == 0
Exemple #6
def test_simple_area_under_curve():
    """Check PSDSEval._auc on a three-point curve with a known area"""
    xs = np.array([0, 1, 2])
    ys = np.array([1, 2, 3])
    # NOTE(review): 3.0 is consistent with a left-hold step integration
    # (1*1 + 2*1) rather than a trapezoid (4.0) - confirm against _auc.
    expected_area = 3.0
    area = PSDSEval._auc(xs, ys)
    assert area == pytest.approx(expected_area), \
        "The area calculation was incorrect"
Exemple #7
def test_simple_area_under_curve_with_max():
    """Check PSDSEval._auc when the curve is limited by a max_x value"""
    xs = np.array([0, 1, 2, 3, 4])
    ys = np.array([1.1, 2.3, 3.5, 4.2, 5.5])
    # NOTE(review): 3.4 equals 1.1 + 2.3, i.e. only the part of the curve
    # up to x = 2 appears to contribute - confirm against _auc.
    expected_area = 3.4
    area = PSDSEval._auc(xs, ys, max_x=2)
    assert area == pytest.approx(expected_area), \
        "The area calculation was incorrect"
Exemple #8
    # NOTE(review): the enclosing function header is not visible in this
    # excerpt; the lines below are the body of an example that computes a
    # PSD-Score from a set of detection files.

    # Thresholds handed positionally to the PSDSEval constructor below
    dtc_threshold = 0.5
    gtc_threshold = 0.5
    cttc_threshold = 0.3
    # Arguments handed to psds_eval.psds() below
    alpha_ct = 0.0
    alpha_st = 0.0
    max_efpr = 100

    # Load metadata and ground truth tables (tab-separated files located in
    # the "data" directory next to this file)
    data_dir = os.path.join(os.path.dirname(__file__), "data")
    ground_truth_csv = os.path.join(data_dir, "dcase2019t4_gt.csv")
    metadata_csv = os.path.join(data_dir, "dcase2019t4_meta.csv")
    gt_table = pd.read_csv(ground_truth_csv, sep="\t")
    meta_table = pd.read_csv(metadata_csv, sep="\t")

    # Instantiate PSDSEval
    psds_eval = PSDSEval(dtc_threshold, gtc_threshold, cttc_threshold,
                         ground_truth=gt_table, metadata=meta_table)

    # Add the operating points, with the attached information.
    # One detection file is read per threshold (nominally 0.1 ... 1.0); the
    # threshold is formatted to one decimal place to build the file name.
    for i, th in enumerate(np.arange(0.1, 1.1, 0.1)):
        csv_file = os.path.join(data_dir, f"baseline_{th:.1f}.csv")
        det_t = pd.read_csv(os.path.join(csv_file), sep="\t")
        info = {"name": f"Op {i + 1}", "threshold": th}
        psds_eval.add_operating_point(det_t, info=info)
        # "\r" with end=" " rewrites the same console line on each iteration
        print(f"\rOperating point {i+1} added", end=" ")

    # Calculate the PSD-Score
    psds = psds_eval.psds(alpha_ct, alpha_st, max_efpr)
    print(f"\nPSD-Score: {psds.value:.5f}")

    # Plot the PSD-ROC
    # NOTE(review): no plotting code follows in this excerpt.
def test_full_dcase_validset():
    """Run PSDSEval on all the example data from DCASE"""
    # NOTE(review): several statements below appear truncated in this copy
    # (unclosed read_csv calls, the PSDSEval(...) constructor, and the
    # tpr/fpr/ctr array literals), and no add_operating_point call is
    # visible before the assertions. Restore the missing lines from the
    # upstream test file before relying on this test.
    det = pd.read_csv(join(DATADIR, "baseline_validation_AA_0.005.csv"),
    gt = pd.read_csv(join(DATADIR, "baseline_validation_gt.csv"), sep="\t")
    metadata = pd.read_csv(join(DATADIR, "baseline_validation_metadata.csv"),
    # Record the checksums of the incoming data
    meta_hash = pd.util.hash_pandas_object(metadata).values
    gt_hash = pd.util.hash_pandas_object(gt).values
    det_hash = pd.util.hash_pandas_object(det).values

    psds_eval = PSDSEval(dtc_threshold=0.5,
    # matrix (n_class, n_class) last col/row is world (for FP)
    exp_counts = np.array(
        [[269, 9, 63, 41, 120, 13, 7, 18, 128, 2, 302],
         [5, 59, 4, 45, 29, 31, 35, 46, 86, 58,
          416], [54, 17, 129, 19, 105, 13, 14, 16, 82, 20, 585],
         [37, 43, 8, 164, 56, 9, 63, 63, 87, 7, 1100],
         [45, 10, 79, 73, 278, 7, 24, 51, 154, 22, 1480],
         [14, 22, 11, 24, 30, 41, 51, 26, 62, 43, 386],
         [3, 20, 12, 136, 96, 35, 87, 103, 97, 27, 840],
         [8, 41, 13, 119, 93, 48, 135, 127, 185, 32, 662],
         [89, 120, 74, 493, 825, 203, 403, 187, 966, 89, 1340],
         [0, 83, 1, 12, 58, 27, 46, 46, 120, 67, 390],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])
    tpr = np.array([
        0.64047619, 0.61458333, 0.37829912, 0.28924162, 0.4877193, 0.63076923,
        0.92553191, 0.53586498, 0.55074116, 0.72826087
    fpr = np.array([
        93.08219178, 128.21917808, 180.30821918, 339.04109589, 456.16438356,
        118.97260274, 258.90410959, 204.04109589, 413.01369863, 120.20547945
    ctr = np.array([[
        0., 39.38051054, 275.66357376, 179.40010356, 525.07347382, 56.88295966,
        30.62928597, 78.76102107, 560.07837208, 8.75122456
                        36.54377132, 0, 29.23501705, 328.89394185,
                        211.95387364, 226.57138217, 255.80639922, 336.20269612,
                        628.55286666, 423.90774728
                        412.06956854, 129.72560491, 0, 144.98744078,
                        801.24638326, 99.20193317, 106.8328511, 122.09468697,
                        625.73527074, 152.61835872
                        375.77227974, 436.70832511, 81.24806048, 0,
                        568.73642339, 91.40406805, 639.82847632, 639.82847632,
                        883.57265777, 71.09205292
                        201.60236547, 44.80052566, 353.92415271, 327.04383731,
                        0, 31.36036796, 107.52126158, 228.48268086,
                        689.92809516, 98.56115645
                        100.2921207, 157.60190396, 78.80095198, 171.92934977,
                        214.91168722, 0, 365.34986827, 186.25679559,
                        444.15082025, 308.04008501
                        13.91555321, 92.77035472, 55.66221283, 630.83841208,
                        445.29770265, 162.34812076, 0, 477.7673268,
                        449.93622038, 125.23997887
                        23.16073227, 118.69875286, 37.63618993, 344.51589244,
                        269.24351258, 138.96439359, 390.83735697, 0,
                        535.59193363, 92.64292906
                        122.13847545, 164.68109049, 101.55333914, 676.56481345,
                        1132.18249714, 278.58551142, 553.05399557,
                        256.62803269, 0., 122.13847545
                        0, 382.88155531, 4.61303079, 55.35636944, 267.55578564,
                        124.55183125, 212.1994162, 212.1994162, 553.56369442, 0
    assert np.all(psds_eval.operating_points.counts[0] == exp_counts)
    np.testing.assert_allclose(psds_eval.operating_points.tpr[0], tpr)
    np.testing.assert_allclose(psds_eval.operating_points.fpr[0], fpr)
    np.testing.assert_allclose(psds_eval.operating_points.ctr[0], ctr)
    psds1 = psds_eval.psds(0.0, 0.0, 100.0)
    # Check that all the psds metrics match
    assert psds1.value == pytest.approx(0.0044306914546640595), \
        "PSDS value was calculated incorrectly"
    # Check that the data has not been messed about with
    assert np.all(pd.util.hash_pandas_object(gt).values == gt_hash)
    assert np.all(pd.util.hash_pandas_object(metadata).values == meta_hash)
    assert np.all(pd.util.hash_pandas_object(det).values == det_hash)