Esempio n. 1
0
    def test_precision_recall(self):
        """Check ``evaluation.range_PR`` on a tiny hand-built example.

        Per query, the reference and new result ids are:

            query  ref        new     common
            0      {1, 2, 3}  {1, 2}  2
            1      {5, 6}     {6, 7}  1
            2      {}         {1}     0
            3      {}         {}      0

        Pooling the counts over all queries gives 3 common ids out of
        5 new results and 5 reference results, i.e. precision = recall
        = 3 / 5 = 0.6 (assumes range_PR's default mode pools counts over
        queries -- confirm against evaluation.range_PR).
        """
        Iref = [[1, 2, 3], [5, 6], [], []]
        Inew = [[1, 2], [6, 7], [1], []]

        # Convert to the flat (lims, I) layout: the ids of query i are
        # I[lims[i]:lims[i + 1]].
        lims_ref = np.cumsum([0] + [len(x) for x in Iref])
        lims_new = np.cumsum([0] + [len(x) for x in Inew])
        # np.hstack turns the empty lists into float64 arrays, which promotes
        # the whole concatenation to float64; cast back so the ids are
        # integer labels rather than floats.
        Iref = np.hstack(Iref).astype("int64")
        Inew = np.hstack(Inew).astype("int64")

        precision, recall = evaluation.range_PR(lims_ref, Iref, lims_new, Inew)

        # Floating-point results: compare with a tolerance rather than exactly.
        self.assertAlmostEqual(precision, 0.6)
        self.assertAlmostEqual(recall, 0.6)
Esempio n. 2
0
    def test_PR_multiple(self):
        """Batched multi-threshold PR must match per-threshold ``range_PR``.

        Reference results come from a flat L2 index queried at radius 15.
        Candidate results come from a PCA16-reduced flat index queried at the
        same radius. For each mode, the candidate results are filtered at
        every threshold in ``all_thr`` and scored one at a time with
        ``evaluation.range_PR``. Those scores must agree (to the default
        precision of ``assert_array_almost_equal``) with a single call to
        ``evaluation.range_PR_multiple_thresholds``.
        """
        dim = 32
        metric = faiss.METRIC_L2
        ds = datasets.SyntheticDataset(dim, 1000, 1000, 10)
        queries = ds.get_queries()
        database = ds.get_database()

        # radius chosen to give on the order of 10k results
        radius = 15

        ref_index = faiss.IndexFlat(dim, metric)
        ref_index.add(database)
        ref_lims, ref_D, ref_I = ref_index.range_search(queries, radius)

        # A deliberately lossy index: projecting to 16 dims with PCA shrinks
        # distances, so it returns more results at the same radius.
        lossy_index = faiss.index_factory(dim, "PCA16,Flat")
        lossy_index.train(ds.get_train())
        lossy_index.add(database)
        new_lims, new_D, new_I = lossy_index.range_search(queries, radius)

        all_thr = np.array([5.0, 10.0, 12.0, 15.0])
        for mode in ("overall", "average"):
            # Expected values: one single-threshold evaluation per threshold.
            expected_prec = np.zeros_like(all_thr)
            expected_rec = np.zeros_like(all_thr)
            for k, thr in enumerate(all_thr):
                sub_lims, _, sub_I = evaluation.filter_range_results(
                    new_lims, new_D, new_I, thr)
                expected_prec[k], expected_rec[k] = evaluation.range_PR(
                    ref_lims, ref_I, sub_lims, sub_I, mode=mode)

            # Actual values: every threshold in one batched call.
            got_prec, got_rec = evaluation.range_PR_multiple_thresholds(
                ref_lims, ref_I, new_lims, new_D, new_I, all_thr, mode=mode)

            np.testing.assert_array_almost_equal(expected_prec, got_prec)
            np.testing.assert_array_almost_equal(expected_rec, got_rec)