def test_custom(self):
    """A caller-provided faiss index is adopted by FaissIndex instead of a factory-built one."""
    import faiss

    flat_index = faiss.IndexFlat(5)
    wrapper = FaissIndex(custom_index=flat_index)
    identity = np.eye(5, dtype=np.float32)
    wrapper.add_vectors(identity)
    # The wrapped index must still be the faiss.IndexFlat type we passed in.
    self.assertIsInstance(wrapper.faiss_index, faiss.IndexFlat)
def test_factory(self):
    """FaissIndex built from a faiss factory string exposes an index of the matching faiss type."""
    # `faiss` was previously referenced here without an import. Import it locally, like the
    # other tests in this class, so the method does not depend on a module-level import.
    import faiss

    # NOTE(review): another method named `test_factory` is defined later in this class; the
    # later definition shadows this one, so this body never runs. Consider removing/renaming.
    index = FaissIndex(string_factory="Flat")
    index.add_vectors(np.eye(5, dtype=np.float32))
    self.assertIsInstance(index.faiss_index, faiss.IndexFlat)

    index = FaissIndex(string_factory="LSH")
    index.add_vectors(np.eye(5, dtype=np.float32))
    self.assertIsInstance(index.faiss_index, faiss.IndexLSH)
def test_flat_ip(self):
    """Exercise add_vectors, search and search_batch on an inner-product FaissIndex."""
    import faiss

    index = FaissIndex(metric_type=faiss.METRIC_INNER_PRODUCT)

    # Adding vectors materialises the underlying faiss index and grows ntotal.
    index.add_vectors(np.eye(5, dtype=np.float32))
    self.assertIsNotNone(index.faiss_index)
    self.assertEqual(index.faiss_index.ntotal, 5)
    index.add_vectors(np.zeros((5, 5), dtype=np.float32))
    self.assertEqual(index.faiss_index.ntotal, 10)

    # Single query: the unit vector along axis 1 is expected to hit stored id 1 first.
    query = np.zeros(5, dtype=np.float32)
    query[1] = 1
    scores, indices = index.search(query)
    self.assertGreater(scores[0], 0)
    self.assertEqual(indices[0], 1)

    # Batched queries: reversed identity rows are expected to hit ids 4, 3, 2, 1, 0.
    reversed_eye = np.eye(5, dtype=np.float32)[::-1]
    batch_scores, batch_indices = index.search_batch(reversed_eye)
    top_scores = [row[0] for row in batch_scores]
    top_ids = [row[0] for row in batch_indices]
    self.assertGreater(np.min(top_scores), 0)
    self.assertListEqual([4, 3, 2, 1, 0], top_ids)
def test_factory(self):
    """string_factory picks the faiss index type, and combining it with custom_index is rejected."""
    import faiss

    cases = (("Flat", faiss.IndexFlat), ("LSH", faiss.IndexLSH))
    for factory_string, expected_type in cases:
        index = FaissIndex(string_factory=factory_string)
        # A fresh array per case, as each index gets its own input vectors.
        index.add_vectors(np.eye(5, dtype=np.float32))
        self.assertIsInstance(index.faiss_index, expected_type)

    # Supplying both a factory string and a ready-made index is ambiguous and must fail.
    with self.assertRaises(ValueError):
        _ = FaissIndex(string_factory="Flat", custom_index=faiss.IndexFlat(5))
def test_serialization(self):
    """An inner-product index saved to disk and loaded back still returns the expected search hit."""
    # Previously `faiss` and `os` were not brought into scope here; import them locally.
    import os

    import faiss

    # NOTE(review): another method named `test_serialization` is defined later in this class; the
    # later definition shadows this one, so this body never runs. Consider removing/renaming.
    index = FaissIndex(metric_type=faiss.METRIC_INNER_PRODUCT)
    index.add_vectors(np.eye(5, dtype=np.float32))

    # A default NamedTemporaryFile (delete=True) cannot be reopened by name while open on Windows,
    # so save/load would raise PermissionError (see https://bugs.python.org/issue14243).
    # Create with delete=False and close the handle, then remove the file ourselves in `finally`
    # so it is cleaned up even if save/load fails.
    with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
        tmp_path = tmp_file.name
    try:
        index.save(tmp_path)
        index = FaissIndex.load(tmp_path)
    finally:
        os.unlink(tmp_path)

    query = np.zeros(5, dtype=np.float32)
    query[1] = 1
    scores, indices = index.search(query)
    self.assertGreater(scores[0], 0)
    self.assertEqual(indices[0], 1)
def test_serialization(self):
    """Round-trip an inner-product FaissIndex through save/load and check search still works."""
    import faiss

    index = FaissIndex(metric_type=faiss.METRIC_INNER_PRODUCT)
    index.add_vectors(np.eye(5, dtype=np.float32))

    # Setting delete=False and unlinking manually is not pretty... but it is required on Windows to
    # ensure somewhat stable behaviour. If we don't, we get PermissionErrors. This is an age-old issue.
    # see https://bugs.python.org/issue14243 and
    # https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file/23212515
    # The handle is closed before save/load and the file is removed in `finally`. This means we never
    # try to delete a file that is still open (which Windows typically refuses), and the temp file is
    # not leaked if save/load raises.
    with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
        tmp_path = tmp_file.name
    try:
        index.save(tmp_path)
        index = FaissIndex.load(tmp_path)
    finally:
        os.unlink(tmp_path)

    query = np.zeros(5, dtype=np.float32)
    query[1] = 1
    scores, indices = index.search(query)
    self.assertGreater(scores[0], 0)
    self.assertEqual(indices[0], 1)