Ejemplo n.º 1
0
    def test_hash_and_multihash(self):
        """Check IndexBinaryHash / IndexBinaryMultiHash against a flat index.

        For every (nh, nbit) combination an index is built over the same
        database, searched with ``nflip = 2``, and the number of returned
        ids that also appear in the exhaustive IndexBinaryFlat result is
        accumulated over all queries into ``nfound[(nh, nbit)]``.

        The test asserts that:
          * no result row contains duplicate ids;
          * for each nh, nbit=4 finds more reference neighbours than nbit=7;
          * a serialize/deserialize round trip of the last index built for
            each nh (the nbit=7 one) returns identical search results;
          * at nbit=7, nh=0 and nh=1 differ by less than 3 hits, and the
            hit count grows from nh=1 to nh=3 to nh=5.
        """
        d, nq, nb = 128, 100, 2000

        (_, xb, xq) = make_binary_dataset(d, 0, nb, nq)

        # Exhaustive search provides the reference neighbours.
        flat_index = faiss.IndexBinaryFlat(d)
        flat_index.add(xb)
        k = 10
        Dref, Iref = flat_index.search(xq, k)

        nfound = {}
        for nh in (0, 1, 3, 5):

            for nbit in (4, 7):
                # nh == 0 selects the plain hash index, anything else the
                # multi-hash variant parameterized by nh.
                index = (
                    faiss.IndexBinaryHash(d, nbit)
                    if nh == 0
                    else faiss.IndexBinaryMultiHash(d, nh, nbit)
                )
                index.add(xb)
                index.nflip = 2
                Dnew, Inew = index.search(xq, k)

                hits = 0
                for q in range(nq):
                    found_ids = Inew[q]
                    unique_ids = set(found_ids)
                    # no duplicates
                    self.assertTrue(len(found_ids) == len(unique_ids))
                    hits += len(set(Iref[q]).intersection(unique_ids))
                print('nfound', nh, nbit, hits)
                nfound[(nh, nbit)] = hits

            coarse_hits = nfound[(nh, 4)]
            fine_hits = nfound[(nh, 7)]
            self.assertGreater(coarse_hits, fine_hits)

            # test serialization
            # `index`, `Dnew` and `Inew` are the ones left over from the
            # last inner iteration (nbit == 7).
            blob = faiss.serialize_index_binary(index)
            restored = faiss.deserialize_index_binary(blob)

            D2, I2 = restored.search(xq, k)
            np.testing.assert_array_equal(Inew, I2)
            np.testing.assert_array_equal(Dnew, D2)

        print('nfound=', nfound)
        hash_vs_single_multihash = abs(nfound[(0, 7)] - nfound[(1, 7)])
        self.assertGreater(3, hash_vs_single_multihash)
        self.assertGreater(nfound[(3, 7)], nfound[(1, 7)])
        self.assertGreater(nfound[(5, 7)], nfound[(3, 7)])
Ejemplo n.º 2
0
 def __getstate__(self):
     """Return the pickle state: ``self.faiss_index`` serialized by faiss.

     The state is whatever ``faiss.serialize_index_binary`` returns for the
     wrapped binary index (presumably a byte buffer -- confirm against the
     matching ``__setstate__``).
     """
     return faiss.serialize_index_binary(self.faiss_index)