Пример #1
0
def test_set_ground_truth_with_no_metadata():
    """Check that set_ground_truth() rejects a call whose metadata is None.

    A ground truth table is loaded from the test data directory and handed to
    a freshly built evaluator together with ``None`` as metadata; the call
    must raise PSDSEvalError with the "Audio metadata is required" message.
    """
    gt_path = os.path.join(DATADIR, "test_1.gt")
    ground_truth = pd.read_csv(gt_path, sep="\t")
    evaluator = PSDSEval()
    expected_message = "Audio metadata is required"
    with pytest.raises(PSDSEvalError, match=expected_message):
        evaluator.set_ground_truth(ground_truth, None)
Пример #2
0
def test_valid_thresholds(x):
    """Check that PSDSEval accepts ``x`` for each of its threshold arguments.

    The value ``x`` is supplied by the caller (presumably a parametrize
    decorator outside this snippet) and is tried, one keyword at a time, as
    the dtc, cttc and gtc threshold.
    """
    for keyword in ("dtc_threshold", "cttc_threshold", "gtc_threshold"):
        assert PSDSEval(**{keyword: x})
Пример #3
0
def test_eval_class_with_no_ground_truth():
    """Check that building PSDSEval with metadata but no ground truth fails.

    The constructor must raise PSDSEvalError with the message
    "The ground truth cannot be set without data".
    """
    metadata_path = os.path.join(DATADIR, "test.metadata")
    audio_metadata = pd.read_csv(metadata_path, sep="\t")
    expected_message = "The ground truth cannot be set without data"
    with pytest.raises(PSDSEvalError, match=expected_message):
        PSDSEval(metadata=audio_metadata, ground_truth=None)
Пример #4
0
def test_eval_class_with_no_metadata():
    """Check that building PSDSEval with a ground truth but no metadata fails.

    The constructor must raise PSDSEvalError with the message
    "Audio metadata is required".
    """
    gt_path = os.path.join(DATADIR, "test_1.gt")
    ground_truth = pd.read_csv(gt_path, sep="\t")
    expected_message = "Audio metadata is required"
    with pytest.raises(PSDSEvalError, match=expected_message):
        PSDSEval(metadata=None, ground_truth=ground_truth)
Пример #5
0
def tests_num_operating_points_without_any_operating_points():
    """Check that a freshly constructed PSDSEval reports zero operating points."""
    fresh_eval = PSDSEval()
    reported = fresh_eval.num_operating_points()
    assert reported == 0
Пример #6
0
def test_simple_area_under_curve():
    """Check PSDSEval._auc on a three-point curve.

    For x = (0, 1, 2) and y = (1, 2, 3) the expected area is 3.0.
    """
    xs = np.array([0, 1, 2])
    ys = np.array([1, 2, 3])
    expected_area = 3.0
    auc = PSDSEval._auc(xs, ys)
    assert auc == pytest.approx(expected_area), \
        "The area calculation was incorrect"
Пример #7
0
def test_simple_area_under_curve_with_max():
    """Check PSDSEval._auc when the ``max_x`` argument is given.

    With ``max_x=2`` on a five-point curve the expected area is 3.4.
    """
    xs = np.array([0, 1, 2, 3, 4])
    ys = np.array([1.1, 2.3, 3.5, 4.2, 5.5])
    expected_area = 3.4
    auc = PSDSEval._auc(xs, ys, max_x=2)
    assert auc == pytest.approx(expected_area), \
        "The area calculation was incorrect"
Пример #8
0
    # Example fragment (the enclosing definition is not visible here): builds
    # a PSDSEval from ground truth and metadata tables, adds one operating
    # point per detection file, then prints and plots the resulting PSD-Score.

    # Thresholds passed positionally to the PSDSEval constructor below.
    # NOTE(review): presumably these are the detection / ground-truth /
    # cross-trigger tolerance criteria -- confirm against the PSDSEval docs.
    dtc_threshold = 0.5
    gtc_threshold = 0.5
    cttc_threshold = 0.3
    # Arguments forwarded, in this order, to psds_eval.psds() further down.
    alpha_ct = 0.0
    alpha_st = 0.0
    max_efpr = 100

    # Load metadata and ground truth tables
    # (tab-separated files located in the "data" directory next to this file;
    # the file names suggest DCASE 2019 task 4 data)
    data_dir = os.path.join(os.path.dirname(__file__), "data")
    ground_truth_csv = os.path.join(data_dir, "dcase2019t4_gt.csv")
    metadata_csv = os.path.join(data_dir, "dcase2019t4_meta.csv")
    gt_table = pd.read_csv(ground_truth_csv, sep="\t")
    meta_table = pd.read_csv(metadata_csv, sep="\t")

    # Instantiate PSDSEval
    psds_eval = PSDSEval(dtc_threshold, gtc_threshold, cttc_threshold,
                         ground_truth=gt_table, metadata=meta_table)

    # Add the operating points, with the attached information
    # NOTE(review): np.arange with a float step can yield an unexpected number
    # of elements because of rounding; the loop relies on it producing the
    # values 0.1 .. 1.0 that match the baseline_<th>.csv file names -- confirm.
    for i, th in enumerate(np.arange(0.1, 1.1, 0.1)):
        # One detection file per threshold, named with one decimal digit.
        csv_file = os.path.join(data_dir, f"baseline_{th:.1f}.csv")
        # os.path.join with a single argument returns the path unchanged.
        det_t = pd.read_csv(os.path.join(csv_file), sep="\t")
        info = {"name": f"Op {i + 1}", "threshold": th}
        psds_eval.add_operating_point(det_t, info=info)
        # "\r" with end=" " rewrites the same console line as a progress note.
        print(f"\rOperating point {i+1} added", end=" ")

    # Calculate the PSD-Score
    psds = psds_eval.psds(alpha_ct, alpha_st, max_efpr)
    print(f"\nPSD-Score: {psds.value:.5f}")

    # Plot the PSD-ROC
    plot_psd_roc(psds)
Пример #9
0
def test_full_dcase_validset():
    """Run PSDSEval end to end on the DCASE baseline validation example data.

    One operating point is added from the baseline detections, and its
    counts, TPR, FPR and CTR are compared against stored reference values.
    The PSD-Score computed with ``psds(0.0, 0.0, 100.0)`` is then checked.
    Finally the input tables are re-hashed to confirm the evaluation did not
    modify them.
    """

    def _row_hashes(frame):
        # Per-row hashes used as a fingerprint of a table's contents.
        return pd.util.hash_pandas_object(frame).values

    det = pd.read_csv(
        join(DATADIR, "baseline_validation_AA_0.005.csv"), sep="\t")
    gt = pd.read_csv(join(DATADIR, "baseline_validation_gt.csv"), sep="\t")
    metadata = pd.read_csv(
        join(DATADIR, "baseline_validation_metadata.csv"), sep="\t")

    # Fingerprint the inputs before they are handed to the evaluator
    meta_hash = _row_hashes(metadata)
    gt_hash = _row_hashes(gt)
    det_hash = _row_hashes(det)

    psds_eval = PSDSEval(
        dtc_threshold=0.5,
        gtc_threshold=0.5,
        cttc_threshold=0.3,
        ground_truth=gt,
        metadata=metadata,
    )

    # matrix (n_class, n_class) last col/row is world (for FP)
    expected_counts = np.array([
        [269, 9, 63, 41, 120, 13, 7, 18, 128, 2, 302],
        [5, 59, 4, 45, 29, 31, 35, 46, 86, 58, 416],
        [54, 17, 129, 19, 105, 13, 14, 16, 82, 20, 585],
        [37, 43, 8, 164, 56, 9, 63, 63, 87, 7, 1100],
        [45, 10, 79, 73, 278, 7, 24, 51, 154, 22, 1480],
        [14, 22, 11, 24, 30, 41, 51, 26, 62, 43, 386],
        [3, 20, 12, 136, 96, 35, 87, 103, 97, 27, 840],
        [8, 41, 13, 119, 93, 48, 135, 127, 185, 32, 662],
        [89, 120, 74, 493, 825, 203, 403, 187, 966, 89, 1340],
        [0, 83, 1, 12, 58, 27, 46, 46, 120, 67, 390],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    ])
    expected_tpr = np.array([
        0.64047619, 0.61458333, 0.37829912, 0.28924162, 0.4877193,
        0.63076923, 0.92553191, 0.53586498, 0.55074116, 0.72826087,
    ])
    expected_fpr = np.array([
        93.08219178, 128.21917808, 180.30821918, 339.04109589, 456.16438356,
        118.97260274, 258.90410959, 204.04109589, 413.01369863, 120.20547945,
    ])
    expected_ctr = np.array([
        [0.0, 39.38051054, 275.66357376, 179.40010356, 525.07347382,
         56.88295966, 30.62928597, 78.76102107, 560.07837208, 8.75122456],
        [36.54377132, 0.0, 29.23501705, 328.89394185, 211.95387364,
         226.57138217, 255.80639922, 336.20269612, 628.55286666,
         423.90774728],
        [412.06956854, 129.72560491, 0.0, 144.98744078, 801.24638326,
         99.20193317, 106.8328511, 122.09468697, 625.73527074, 152.61835872],
        [375.77227974, 436.70832511, 81.24806048, 0.0, 568.73642339,
         91.40406805, 639.82847632, 639.82847632, 883.57265777, 71.09205292],
        [201.60236547, 44.80052566, 353.92415271, 327.04383731, 0.0,
         31.36036796, 107.52126158, 228.48268086, 689.92809516, 98.56115645],
        [100.2921207, 157.60190396, 78.80095198, 171.92934977, 214.91168722,
         0.0, 365.34986827, 186.25679559, 444.15082025, 308.04008501],
        [13.91555321, 92.77035472, 55.66221283, 630.83841208, 445.29770265,
         162.34812076, 0.0, 477.7673268, 449.93622038, 125.23997887],
        [23.16073227, 118.69875286, 37.63618993, 344.51589244, 269.24351258,
         138.96439359, 390.83735697, 0.0, 535.59193363, 92.64292906],
        [122.13847545, 164.68109049, 101.55333914, 676.56481345,
         1132.18249714, 278.58551142, 553.05399557, 256.62803269, 0.0,
         122.13847545],
        [0.0, 382.88155531, 4.61303079, 55.35636944, 267.55578564,
         124.55183125, 212.1994162, 212.1994162, 553.56369442, 0.0],
    ])

    psds_eval.add_operating_point(det)

    # Compare the single stored operating point against the references
    assert np.all(psds_eval.operating_points.counts[0] == expected_counts)
    np.testing.assert_allclose(psds_eval.operating_points.tpr[0],
                               expected_tpr)
    np.testing.assert_allclose(psds_eval.operating_points.fpr[0],
                               expected_fpr)
    np.testing.assert_allclose(psds_eval.operating_points.ctr[0],
                               expected_ctr)

    # Check that all the psds metrics match
    psds1 = psds_eval.psds(0.0, 0.0, 100.0)
    assert psds1.value == pytest.approx(0.0044306914546640595), \
        "PSDS value was calculated incorrectly"

    # Check that the data has not been messed about with
    assert np.all(_row_hashes(gt) == gt_hash)
    assert np.all(_row_hashes(metadata) == meta_hash)
    assert np.all(_row_hashes(det) == det_hash)