Ejemplo n.º 1
0
    def test_ivf_hnsw(self):
        """Round-trip an IndexBinaryIVF with an HNSW quantizer through disk.

        Builds the index, searches it, writes it with
        ``faiss.write_index_binary``, reads it back with
        ``faiss.read_index_binary`` and checks that the reloaded index
        returns exactly the same distances and labels.
        """
        # NOTE(review): the * 8 presumably converts bytes per vector into
        # bits per vector -- confirm against the fixture that builds xq.
        d = self.xq.shape[1] * 8

        quantizer = faiss.IndexBinaryHNSW(d)
        index = faiss.IndexBinaryIVF(quantizer, d, 8)
        index.cp.min_points_per_centroid = 5  # quiet warning
        index.nprobe = 4
        index.train(self.xt)
        index.add(self.xb)
        D, I = index.search(self.xq, 3)

        # mkstemp() hands back an already-open OS-level descriptor.  Only the
        # path is needed here, so close the descriptor right away instead of
        # leaking it (an open handle also blocks os.remove() on Windows).
        fd, tmpnam = tempfile.mkstemp()
        os.close(fd)

        try:
            faiss.write_index_binary(index, tmpnam)

            index2 = faiss.read_index_binary(tmpnam)

            D2, I2 = index2.search(self.xq, 3)

            assert (I2 == I).all()
            assert (D2 == D).all()

        finally:
            # Always delete the temporary file, even when a check fails.
            os.remove(tmpnam)
Ejemplo n.º 2
0
    def test_hnsw_exact_distances(self):
        """Check each distance returned by IndexBinaryHNSW against binary_dis.

        For every query and every returned neighbour, the distance reported
        by the index must equal ``binary_dis`` computed directly on the query
        vector and the corresponding database vector.
        """
        n_queries = self.xq.shape[0]
        n_dims = self.xq.shape[1] * 8

        hnsw = faiss.IndexBinaryHNSW(n_dims, 16)
        hnsw.add(self.xb)
        Dists, Ids = hnsw.search(self.xq, 3)

        for qno in range(n_queries):
            query = self.xq[qno]
            hits = zip(Ids[qno], Dists[qno])
            for neighbor, reported in hits:
                expected = binary_dis(query, self.xb[neighbor])
                self.assertEqual(reported, expected)
Ejemplo n.º 3
0
    def test_hnsw(self):
        """Compare IndexBinaryHNSW distances with a float HNSW reference.

        The reference is an ``IndexHNSWFlat`` wrapped in
        ``IndexBinaryFromFloat``.  Both indexes receive the same database
        vectors and are queried with the same vectors; the returned
        distances must match exactly.
        """
        d = self.xq.shape[1] * 8

        # NOTE(hoss): Ensure the HNSW construction is deterministic.
        nthreads = faiss.omp_get_max_threads()
        faiss.omp_set_num_threads(1)

        try:
            index_hnsw_float = faiss.IndexHNSWFlat(d, 16)
            index_hnsw_ref = faiss.IndexBinaryFromFloat(index_hnsw_float)

            index_hnsw_bin = faiss.IndexBinaryHNSW(d, 16)

            index_hnsw_ref.add(self.xb)
            index_hnsw_bin.add(self.xb)
        finally:
            # The thread count is process-wide state: restore it even when
            # construction fails so later tests are not left single-threaded.
            faiss.omp_set_num_threads(nthreads)

        # Only the distances are compared; the labels are not needed.
        Dref, _ = index_hnsw_ref.search(self.xq, 3)
        Dbin, _ = index_hnsw_bin.search(self.xq, 3)

        self.assertTrue((Dref == Dbin).all())
Ejemplo n.º 4
0
    def test_hnsw(self):
        """Round-trip an IndexBinaryHNSW through disk and compare searches.

        Builds the index, searches it, writes it with
        ``faiss.write_index_binary``, reads it back with
        ``faiss.read_index_binary`` and checks that the reloaded index
        returns exactly the same distances and labels.
        """
        # NOTE(review): the * 8 presumably converts bytes per vector into
        # bits per vector -- confirm against the fixture that builds xq.
        d = self.xq.shape[1] * 8

        index = faiss.IndexBinaryHNSW(d)
        index.add(self.xb)
        D, I = index.search(self.xq, 3)

        # mkstemp() hands back an already-open OS-level descriptor.  Only the
        # path is needed here, so close the descriptor right away instead of
        # leaking it (an open handle also blocks os.remove() on Windows).
        fd, tmpnam = tempfile.mkstemp()
        os.close(fd)

        try:
            faiss.write_index_binary(index, tmpnam)

            index2 = faiss.read_index_binary(tmpnam)

            D2, I2 = index2.search(self.xq, 3)

            assert (I2 == I).all()
            assert (D2 == D).all()

        finally:
            # Always delete the temporary file, even when a check fails.
            os.remove(tmpnam)