예제 #1
0
    def test_custom(self):
        """A user-supplied ``custom_index`` is the index that receives the vectors.

        Builds a 5-dimensional ``faiss.IndexFlat``, hands it to ``FaissIndex``
        via ``custom_index`` and adds the 5x5 identity matrix.
        """
        import faiss

        custom_index = faiss.IndexFlat(5)
        index = FaissIndex(custom_index=custom_index)
        index.add_vectors(np.eye(5, dtype=np.float32))
        self.assertIsInstance(index.faiss_index, faiss.IndexFlat)
        # The isinstance check alone would also pass if FaissIndex ignored
        # custom_index and created its own flat index; checking the caller's
        # object proves the supplied index is the one that was populated.
        self.assertEqual(custom_index.ntotal, 5)
예제 #2
0
 def test_factory(self):
     """Each ``string_factory`` spec produces the matching faiss index class."""
     expectations = (
         ("Flat", faiss.IndexFlat),
         ("LSH", faiss.IndexLSH),
     )
     for factory_spec, expected_type in expectations:
         built = FaissIndex(string_factory=factory_spec)
         built.add_vectors(np.eye(5, dtype=np.float32))
         self.assertIsInstance(built.faiss_index, expected_type)
예제 #3
0
    def test_flat_ip(self):
        """Inner-product flat index: insertion counts, one query, then a batch."""
        import faiss

        ip_index = FaissIndex(metric_type=faiss.METRIC_INNER_PRODUCT)

        # Populate with the 5 unit basis vectors, then 5 all-zero rows.
        ip_index.add_vectors(np.eye(5, dtype=np.float32))
        self.assertIsNotNone(ip_index.faiss_index)
        self.assertEqual(ip_index.faiss_index.ntotal, 5)
        ip_index.add_vectors(np.zeros((5, 5), dtype=np.float32))
        self.assertEqual(ip_index.faiss_index.ntotal, 10)

        # A query equal to basis vector e1 should rank row 1 first.
        probe = np.zeros(5, dtype=np.float32)
        probe[1] = 1
        hit_scores, hit_ids = ip_index.search(probe)
        self.assertGreater(hit_scores[0], 0)
        self.assertEqual(hit_ids[0], 1)

        # Basis vectors in reverse order: the top hit for query k is row 4 - k.
        reversed_basis = np.eye(5, dtype=np.float32)[::-1]
        batch_scores, batch_ids = ip_index.search_batch(reversed_basis)
        top_scores = [row[0] for row in batch_scores]
        top_ids = [row[0] for row in batch_ids]
        self.assertGreater(np.min(top_scores), 0)
        self.assertListEqual([4, 3, 2, 1, 0], top_ids)
예제 #4
0
    def test_factory(self):
        """``string_factory`` picks the index class; pairing it with ``custom_index`` fails."""
        import faiss

        for spec, index_cls in (("Flat", faiss.IndexFlat), ("LSH", faiss.IndexLSH)):
            built = FaissIndex(string_factory=spec)
            built.add_vectors(np.eye(5, dtype=np.float32))
            self.assertIsInstance(built.faiss_index, index_cls)

        # Passing both a factory string and a ready-made index must raise ValueError.
        with self.assertRaises(ValueError):
            FaissIndex(string_factory="Flat", custom_index=faiss.IndexFlat(5))
예제 #5
0
 def test_serialization(self):
     """An inner-product index survives a save/load round trip and still answers queries."""
     import os

     index = FaissIndex(metric_type=faiss.METRIC_INNER_PRODUCT)
     index.add_vectors(np.eye(5, dtype=np.float32))

     # With the default delete=True, the temp file cannot be reopened by name
     # on Windows while our handle is open (PermissionError). Create it with
     # delete=False, close our handle before faiss uses the path, and remove
     # the file in `finally` so it is cleaned up even if save/load fails.
     with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
         tmp_path = tmp_file.name
     try:
         index.save(tmp_path)
         loaded = FaissIndex.load(tmp_path)
     finally:
         os.unlink(tmp_path)

     query = np.zeros(5, dtype=np.float32)
     query[1] = 1
     scores, indices = loaded.search(query)
     self.assertGreater(scores[0], 0)
     self.assertEqual(indices[0], 1)
예제 #6
0
    def test_serialization(self):
        """An inner-product index survives a save/load round trip and still answers queries."""
        import faiss

        index = FaissIndex(metric_type=faiss.METRIC_INNER_PRODUCT)
        index.add_vectors(np.eye(5, dtype=np.float32))

        # Setting delete=False and unlinking manually is required on Windows: a
        # NamedTemporaryFile that is still open cannot be reopened by name there,
        # which otherwise leads to PermissionErrors. This is an age-old issue.
        # see https://bugs.python.org/issue14243 and
        # https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file/23212515
        # We only use the context manager to create the file and learn its name.
        # Our handle is closed before faiss writes to the path, and the file is
        # removed in `finally` so it does not leak if save/load raises.
        with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
            tmp_path = tmp_file.name
        try:
            index.save(tmp_path)
            loaded = FaissIndex.load(tmp_path)
        finally:
            os.unlink(tmp_path)

        query = np.zeros(5, dtype=np.float32)
        query[1] = 1
        scores, indices = loaded.search(query)
        self.assertGreater(scores[0], 0)
        self.assertEqual(indices[0], 1)