Ejemplo n.º 1
0
    def subtest(self, mt):
        """Validate IndexIVFPQ search quality for the faiss metric ``mt``.

        An exact ``IndexFlat`` supplies ground-truth 10-NN ids. For each
        ``by_residual`` setting an IVFPQ index (nlist = 64, PQ arguments 4
        and 8) is trained and searched with ``nprobe = 4``. Its overlap with
        the ground truth must stay within 2 of ``self.ref_results``.
        Searching again with ``use_precomputed_table = 0`` must return the
        same ids. The residual variant is additionally searched with
        ``polysemous_ht = 20`` and checked with a wider margin.
        """
        dim, nlist, k = 32, 64, 10
        xt, xb, xq = get_dataset_2(dim, 1000, 2000, 200)

        exact = faiss.IndexFlat(dim, mt)
        exact.add(xb)
        truth_ids = exact.search(xq, k)[1]

        # A single coarse-quantizer object is handed to both IVFPQ indexes.
        coarse = faiss.IndexFlat(dim, mt)

        for residual in (True, False):
            ivfpq = faiss.IndexIVFPQ(coarse, dim, nlist, 4, 8)
            ivfpq.metric_type = mt
            ivfpq.by_residual = residual
            if residual:
                # keep the polysemous training cheap (a single redo)
                ivfpq.do_polysemous_training = True
                trainer = faiss.PolysemousTraining()
                trainer.n_iter = 50000
                trainer.n_redo = 1
                ivfpq.polysemous_training = trainer

            ivfpq.train(xt)
            ivfpq.add(xb)
            ivfpq.nprobe = 4
            ids = ivfpq.search(xq, k)[1]

            overlap = faiss.eval_intersection(ids, truth_ids)
            print('(%d, %s): %d, ' % (mt, residual, overlap))
            assert overlap >= self.ref_results[mt, residual] - 2

            # turning the precomputed table off must not change the ids
            ivfpq.use_precomputed_table = 0
            ids_no_table = ivfpq.search(xq, k)[1]
            assert np.all(ids == ids_no_table)

            if not residual:
                continue

            ivfpq.use_precomputed_table = 1
            ivfpq.polysemous_ht = 20
            poly_ids = ivfpq.search(xq, k)[1]
            overlap = faiss.eval_intersection(poly_ids, truth_ids)
            ht = ivfpq.polysemous_ht
            print('(%d, %s, %d): %d, ' % (mt, residual, ht, overlap))

            # polysemous behaves bizarrely on ARM, hence the wider margin
            assert overlap >= self.ref_results[mt, residual, ht] - 4
Ejemplo n.º 2
0
    def subtest(self, mt):
        """Check IndexIVFPQ k-NN accuracy, precomputed tables and range search.

        For ``by_residual`` in (True, False), an IVFPQ index is trained on a
        synthetic dataset. Its 10-NN ids are compared against an exact
        ``IndexFlat`` ground truth. ``self.ref_results`` holds the expected
        intersection counts, keyed by ``(mt, by_residual)``. For the
        polysemous check the key is ``(mt, by_residual, polysemous_ht)``.

        The test also checks that disabling the precomputed table leaves the
        k-NN ids unchanged. When encoding residuals it runs a polysemous
        search. Finally it verifies that ``range_search`` agrees with the
        exact (non-polysemous) k-NN results.

        Args:
            mt: faiss metric type. ``faiss.METRIC_INNER_PRODUCT`` is treated
                as a similarity (larger is closer). Any other metric is
                treated as a distance (smaller is closer).
        """
        d = 32
        xt, xb, xq = get_dataset_2(d, 2000, 1000, 200)
        nlist = 64
        is_ip = mt == faiss.METRIC_INNER_PRODUCT

        gt_index = faiss.IndexFlat(d, mt)
        gt_index.add(xb)
        _, gt_I = gt_index.search(xq, 10)
        # The same coarse-quantizer object is handed to both IVFPQ indexes.
        # NOTE(review): the second index presumably reuses the centroids
        # trained by the first; ref_results were recorded with this setup.
        quantizer = faiss.IndexFlat(d, mt)
        for by_residual in True, False:

            index = faiss.IndexIVFPQ(quantizer, d, nlist, 4, 8)
            index.metric_type = mt
            index.by_residual = by_residual
            if by_residual:
                # perform cheap polysemous training
                index.do_polysemous_training = True
                pt = faiss.PolysemousTraining()
                pt.n_iter = 50000
                pt.n_redo = 1
                index.polysemous_training = pt

            index.train(xt)
            index.add(xb)
            index.nprobe = 4
            _, I = index.search(xq, 10)

            ninter = faiss.eval_intersection(I, gt_I)
            print('(%d, %s): %d, ' % (mt, by_residual, ninter))

            assert abs(ninter - self.ref_results[mt, by_residual]) <= 3

            # D2 / I2 are the exact (non-polysemous) k-NN results. They also
            # serve as the reference for the range-search check below.
            index.use_precomputed_table = 0
            D2, I2 = index.search(xq, 10)
            assert np.all(I == I2)

            if by_residual:
                index.use_precomputed_table = 1
                ht_before = index.polysemous_ht
                index.polysemous_ht = 20
                _, I_poly = index.search(xq, 10)
                ninter = faiss.eval_intersection(I_poly, gt_I)
                print('(%d, %s, %d): %d, ' % (
                    mt, by_residual, index.polysemous_ht, ninter))

                # polysemous behaves bizarrely on ARM
                assert (ninter >= self.ref_results[
                    mt, by_residual, index.polysemous_ht] - 4)

                # Restore the previous setting so the range search below
                # scans the same way as the D2 / I2 reference search.
                index.polysemous_ht = ht_before

            # Also test range search. The radius is the max (IP) or min
            # (distance) of every query's 10th score in D2. So for each
            # query, results strictly inside the radius should all be among
            # its 10 reference neighbours (up to ties / float rounding).
            # That makes I2 a complete reference set.
            if is_ip:
                radius = float(D2[:, -1].max())
            else:
                radius = float(D2[:, -1].min())
            print('radius', radius)

            lims, _, I3 = index.range_search(xq, radius)
            ntot = ndiff = 0
            for l0, l1, d_row, i_row in zip(lims[:-1], lims[1:], D2, I2):
                Inew = set(I3[l0:l1])
                mask = d_row > radius if is_ip else d_row < radius
                Iref = set(i_row[mask])
                ndiff += len(Inew ^ Iref)
                ntot += len(Iref)
            print('ndiff %d / %d' % (ndiff, ntot))
            assert ndiff < ntot * 0.02