Exemple #1
0
    def test_equiv_sh(self):
        """Check that IndexIVFSpectralHash.sa_encode produces the same codes
        as the sa_encode of a "RQ1x4,Refine(ITQ16,LSH)" factory index."""
        dataset = SyntheticDataset(32, 500, 100, 0)

        # reference index: RQ base index refined by an ITQ + LSH encoder
        reference = faiss.index_factory(dataset.d, "RQ1x4,Refine(ITQ16,LSH)")
        reference.train(dataset.get_train())

        # rebuild the same thing as an IndexIVFSpectralHash: its coarse
        # quantizer stores the first entry returned by
        # get_additive_quantizer_codebooks for the trained RQ
        trained_rq = faiss.downcast_index(reference.base_index).rq
        codebooks = get_additive_quantizer_codebooks(trained_rq)
        flat_quantizer = faiss.IndexFlat(dataset.d)
        flat_quantizer.add(codebooks[0])

        lsh_encoder = faiss.downcast_index(reference.refine_index)
        nbit = lsh_encoder.sa_code_size() * 8

        # the period is larger than the magnitude of the vectors, and
        # negative because the bits are flipped otherwise
        period = -100000.0

        spectral = faiss.IndexIVFSpectralHash(
            flat_quantizer, dataset.d, flat_quantizer.ntotal, nbit, period
        )

        # swap in the vt of the encoder; the binarization itself is
        # performed by the IndexIVFSpectralHash
        spectral.replace_vt(lsh_encoder)

        expected_codes = reference.sa_encode(dataset.get_database())
        actual_codes = spectral.sa_encode(dataset.get_database())

        np.testing.assert_array_equal(expected_codes, actual_codes)
    def test_sh(self):
        """Compare IndexIVFSpectralHash search results with reference values.

        For every combination of code size, period and threshold type, the
        intersection count between the index results and the results of an
        IndexFlatL2 search must be within 12 of the value stored in
        self.ref_results under the key (nbit, threshold name, period).
        """
        d = 32
        k = 10
        xt, xb, xq = get_dataset_2(d, 2000, 1000, 200)
        nlist = nprobe = 1
        threshold_names = 'global centroid centroid_half median'.split()

        # ground truth comes from a flat L2 index
        exact_index = faiss.IndexFlatL2(d)
        exact_index.add(xb)
        _, gt_ids = exact_index.search(xq, k)

        for nbit in (32, 64, 128):
            # one quantizer object is shared by all indexes built for this nbit
            quantizer = faiss.IndexFlatL2(d)

            # IndexLSH result, only printed as a baseline
            lsh_index = faiss.IndexLSH(d, nbit, True)
            lsh_index.add(xb)
            _, lsh_ids = lsh_index.search(xq, k)
            print('LSH baseline: %d' % faiss.eval_intersection(lsh_ids, gt_ids))

            for period in (10.0, 1.0):
                for tt in threshold_names:
                    sh_index = faiss.IndexIVFSpectralHash(
                        quantizer, d, nlist, nbit, period
                    )
                    sh_index.nprobe = nprobe
                    # threshold constants are class attributes named Thresh_<tt>
                    sh_index.threshold_type = getattr(
                        faiss.IndexIVFSpectralHash, 'Thresh_' + tt
                    )

                    sh_index.train(xt)
                    sh_index.add(xb)
                    _, found_ids = sh_index.search(xq, k)

                    ninter = faiss.eval_intersection(found_ids, gt_ids)
                    print('(%d, %s, %g): %d, ' % (nbit, repr(tt), period, ninter))

                    expected = self.ref_results[(nbit, tt, period)]
                    assert abs(ninter - expected) <= 12