def test_multi_ops_multiple_times_sequentially():
    """Compute the same PSDS three times, clearing the operating points in between.

    Six operating points are used: one full detection table plus five seeded
    random 4500-row subsets of it. In each of three rounds they are added, the
    PSDS is computed and compared against a reference value, and then
    ``clear_all_operating_points`` resets the evaluator. After each reset the
    evaluator must report zero operating points.
    """
    gt = pd.read_csv(join(DATADIR, "baseline_validation_gt.csv"), sep="\t")
    metadata = pd.read_csv(join(DATADIR, "baseline_validation_metadata.csv"),
                           sep="\t")
    base_det = pd.read_csv(join(DATADIR, "baseline_validation_AA_0.005.csv"),
                           sep="\t")
    # Fixed random_state values keep the sampled subsets, and therefore the
    # reference PSDS value below, reproducible.
    dets = [base_det] + [base_det.sample(4500, random_state=7 * k)
                         for k in range(5)]
    psds_eval = PSDSEval(dtc_threshold=0.5, gtc_threshold=0.5,
                         cttc_threshold=0.3, ground_truth=gt,
                         metadata=metadata)
    ref_psds_value = 0.07224283564515908
    for _ in range(3):
        for det_t in dets:
            psds_eval.add_operating_point(det_t)
        psds = psds_eval.psds(0.0, 0.0, 100)
        assert psds.value == pytest.approx(ref_psds_value), \
            "PSDS was calculated incorrectly"
        psds_eval.clear_all_operating_points()
        assert psds_eval.num_operating_points() == 0
        assert len(psds_eval.operating_points) == 0
def test_that_add_operating_point_added_a_point():
    """Adding one detection table must register exactly one operating point.

    The id stored for that operating point is compared with a known value.
    """
    def load(fname):
        # All fixtures are tab-separated files under DATADIR.
        return pd.read_csv(os.path.join(DATADIR, fname), sep="\t")

    det = load("test_1.det")
    metadata = load("test.metadata")
    gt = load("test_1.gt")

    evaluator = PSDSEval(metadata=metadata, ground_truth=gt)
    evaluator.add_operating_point(det)

    expected_id = \
        "423089ce6d6554174881f69f9d0e57a8be9f5bc682dfce301462a8753aa6ec5f"
    assert evaluator.num_operating_points() == 1
    assert evaluator.operating_points["id"][0] == expected_id
def test_add_operating_point_with_zero_detections():
    """An empty detection table must be accepted as an operating point.

    No error may be raised, and exactly one operating point is registered.
    """
    def load(fname):
        # All fixtures are tab-separated files under DATADIR.
        return pd.read_csv(os.path.join(DATADIR, fname), sep="\t")

    empty_det = load("empty.det")
    metadata = load("test.metadata")
    gt = load("test_1.gt")

    evaluator = PSDSEval(metadata=metadata, ground_truth=gt)
    evaluator.add_operating_point(empty_det)

    # This value is the SHA-256 hex digest of empty input.
    expected_id = \
        "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
    assert evaluator.num_operating_points() == 1
    assert evaluator.operating_points["id"][0] == expected_id
def test_add_operating_point_id_is_deterministic():
    """Ensure the same detection table always yields the same operating point id.

    The same ``test_1.det`` table is added to two independent evaluators. Each
    evaluator must register exactly one operating point, and both ids must be
    equal.
    """
    # This test used to redefine ``test_that_add_operating_point_added_a_point``.
    # That shadowed the earlier test of the same name, so pytest never ran it.
    # It also hard-coded a different id for the same input than the other tests
    # in this file.
    det = pd.read_csv(os.path.join(DATADIR, "test_1.det"), sep="\t")
    metadata = pd.read_csv(os.path.join(DATADIR, "test.metadata"), sep="\t")
    gt = pd.read_csv(os.path.join(DATADIR, "test_1.gt"), sep="\t")
    ids = []
    for _ in range(2):
        psds_eval = PSDSEval(metadata=metadata, ground_truth=gt)
        # Pass a fresh copy so one run cannot affect the other's input.
        psds_eval.add_operating_point(det.copy())
        assert psds_eval.num_operating_points() == 1
        ids.append(psds_eval.operating_points["id"][0])
    assert ids[0] == ids[1], \
        "The same detection table produced different operating point ids."
def test_delete_ops():
    """Clearing the operating points must leave the evaluator's table empty."""
    def load(fname):
        # All fixtures are tab-separated files under DATADIR.
        return pd.read_csv(os.path.join(DATADIR, fname), sep="\t")

    metadata = load("test.metadata")
    det = load("test_1.det")
    det_2 = load("test_1a.det")
    gt = load("test_1.gt")

    psds_eval = PSDSEval(dtc_threshold=0.5, gtc_threshold=0.5,
                         cttc_threshold=0.3, ground_truth=gt,
                         metadata=metadata)
    assert psds_eval.operating_points.empty

    for table in (det, det_2):
        psds_eval.add_operating_point(table)
    assert psds_eval.num_operating_points() == 2

    psds_eval.clear_all_operating_points()
    assert psds_eval.operating_points.empty
def test_adding_shuffled_operating_points():
    """Avoid the addition of the same operating point after shuffling.

    ``test_1.det`` is added once. It is then re-added with its rows shuffled,
    and again with its columns reordered. Neither variant may create a new
    operating point, and the stored id must be that of the original table.
    """
    det = pd.read_csv(os.path.join(DATADIR, "test_1.det"), sep="\t")
    metadata = pd.read_csv(os.path.join(DATADIR, "test.metadata"), sep="\t")
    gt = pd.read_csv(os.path.join(DATADIR, "test_1.gt"), sep="\t")
    psds_eval = PSDSEval(metadata=metadata, ground_truth=gt)
    psds_eval.add_operating_point(det)
    # Seed the row shuffle so that any failure can be reproduced exactly.
    # sample() already returns a new frame, so no prior copy is needed.
    det_shuffled = det.sample(frac=1., random_state=0).reset_index(drop=True)
    psds_eval.add_operating_point(det_shuffled)
    # Same rows, different column order.
    det_shuffled2 = det[["onset", "event_label", "offset", "filename"]].copy()
    psds_eval.add_operating_point(det_shuffled2)
    assert psds_eval.num_operating_points() == 1
    assert psds_eval.operating_points["id"][0] == \
        "423089ce6d6554174881f69f9d0e57a8be9f5bc682dfce301462a8753aa6ec5f"
def test_add_operating_points_with_overlapping_events(table_name, raise_error):
    """Detections with overlapping events must raise an error.

    ``table_name`` is a detection file under ``DATADIR``. ``raise_error``
    states whether adding it must fail with ``PSDSEvalError``. When it must
    not fail, the table is expected to register exactly one operating point.
    """
    # NOTE(review): table_name/raise_error are presumably supplied by a
    # pytest.mark.parametrize decorator that is not visible here; confirm.
    def load(fname):
        return pd.read_csv(os.path.join(DATADIR, fname), sep="\t")

    metadata = load("test.metadata")
    det = load(table_name)
    gt = load("test_1.gt")
    psds_eval = PSDSEval(dtc_threshold=0.5, gtc_threshold=0.5,
                         cttc_threshold=0.3, ground_truth=gt,
                         metadata=metadata)

    if not raise_error:
        psds_eval.add_operating_point(det)
        assert psds_eval.num_operating_points() == 1
        return

    expected_msg = ("The detection dataframe provided has intersecting "
                    "events/labels for the same class.")
    with pytest.raises(PSDSEvalError, match=expected_msg):
        psds_eval.add_operating_point(det)
def test_add_same_operating_point_with_different_info():
    """Check the use of conflicting info for the same operating point.

    The same detection table is added twice, each time with a different info
    dict. Only the first addition may count. Its info fields must be kept, and
    the second call's extra key must not appear as a column.
    """
    def load(fname):
        return pd.read_csv(os.path.join(DATADIR, fname), sep="\t")

    metadata = load("test.metadata")
    det1 = load("test_1.det")
    gt = load("test_1.gt")
    first_info = {"name": "test_1", "threshold1": 1}
    second_info = {"name": "test_1_2", "threshold2": 0}

    psds_eval = PSDSEval(dtc_threshold=0.5, gtc_threshold=0.5,
                         cttc_threshold=0.3, ground_truth=gt,
                         metadata=metadata)
    for info in (first_info, second_info):
        psds_eval.add_operating_point(det1, info=info)

    ops = psds_eval.operating_points
    assert psds_eval.num_operating_points() == 1
    assert ops["name"][0] == "test_1", \
        "The info name is not correctly reported."
    assert ops["threshold1"][0] == 1, \
        "The info threshold1 is not correctly reported."
    assert "threshold2" not in ops.columns, \
        "The info of ignored operating point modified the operating " \
        "points table."
def tests_num_operating_points_without_any_operating_points():
    """A freshly constructed evaluator reports zero operating points."""
    fresh_eval = PSDSEval()
    op_count = fresh_eval.num_operating_points()
    assert op_count == 0