Ejemplo n.º 1
0
    def test_L2(self):
        """Check faiss.pairwise_indexed_L2sqr against a NumPy reference.

        Random rows of two matrices are paired by random index arrays and the
        squared L2 distance computed by faiss for each pair is compared with
        the same quantity computed directly in NumPy.
        """
        dim, npairs = 10, 50
        xb = faiss.rand((100, dim), seed=1)
        yb = faiss.rand((200, dim), seed=2)
        xi = faiss.randint(npairs, vmax=100)
        yi = faiss.randint(npairs, vmax=200)
        out = np.empty(npairs, dtype='float32')
        ptr = faiss.swig_ptr
        faiss.pairwise_indexed_L2sqr(
            dim, npairs, ptr(xb), ptr(xi), ptr(yb), ptr(yi), ptr(out))

        # out[j] must equal ||xb[xi[j]] - yb[yi[j]]||^2 for every pair j.
        for got, a, b in zip(out, xi, yi):
            diff = xb[a] - yb[b]
            assert np.allclose(got, (diff ** 2).sum())
Ejemplo n.º 2
0
    def test_IP(self):
        """Check faiss.pairwise_indexed_inner_product against NumPy dot products.

        Random rows of two matrices are paired by random index arrays and the
        inner product computed by faiss for each pair is compared with
        np.dot on the same rows.
        """
        dim, npairs = 10, 50
        xb = faiss.rand((100, dim), seed=1)
        yb = faiss.rand((200, dim), seed=2)
        xi = faiss.randint(npairs, vmax=100)
        yi = faiss.randint(npairs, vmax=200)
        out = np.empty(npairs, dtype='float32')
        ptr = faiss.swig_ptr
        faiss.pairwise_indexed_inner_product(
            dim, npairs, ptr(xb), ptr(xi), ptr(yb), ptr(yi), ptr(out))

        # out[j] must equal <xb[xi[j]], yb[yi[j]]> for every pair j.
        for got, a, b in zip(out, xi, yi):
            assert np.allclose(got, np.dot(xb[a], yb[b]))
Ejemplo n.º 3
0
    def run_test(self, keep_max):
        """Check faiss.ResultHeap against a brute-force top-k reference.

        A random (nq, nb) score table with random ids is fed to a ResultHeap
        once in a single chunk and once split into 3 column chunks. Each run
        must produce sorted output equal to the k best scores of each row
        (obtained by sorting the full table), and both runs must agree on
        distances and ids.

        keep_max: if True the heap keeps the largest scores, returned in
        decreasing order; otherwise the smallest, in increasing order.
        """
        nq, nb, k = 100, 1000, 10
        restab = faiss.rand((nq, nb), 123)
        ids = faiss.randint((nq, nb), 1324, 10000)

        # Brute-force reference: the k best scores of each row, in the same
        # order the heap is expected to return them.
        ref_D = np.sort(restab, axis=1)
        if keep_max:
            ref_D = ref_D[:, ::-1]
        ref_D = ref_D[:, :k]

        all_rh = {}
        for nstep in (1, 3):
            rh = faiss.ResultHeap(nq, k, keep_max=keep_max)
            for i in range(nstep):
                # Split the nb columns into nstep contiguous chunks.
                i0, i1 = i * nb // nstep, (i + 1) * nb // nstep
                # .copy() yields C-contiguous chunks (the column slices are
                # strided); presumably needed by add_result -- TODO confirm.
                D = restab[:, i0:i1].copy()
                I = ids[:, i0:i1].copy()
                rh.add_result(D, I)
            rh.finalize()
            if keep_max:
                assert np.all(rh.D[:, :-1] >= rh.D[:, 1:])
            else:
                assert np.all(rh.D[:, :-1] <= rh.D[:, 1:])
            # Consistency between runs alone would not catch a heap that is
            # wrong in the same way every time, so compare with the reference.
            np.testing.assert_equal(rh.D, ref_D)
            all_rh[nstep] = rh

        np.testing.assert_equal(all_rh[1].D, all_rh[3].D)
        np.testing.assert_equal(all_rh[1].I, all_rh[3].I)
Ejemplo n.º 4
0
 def test_randint(self):
     """Check that faiss.randint draws in [0, vmax) with a roughly flat histogram."""
     n, vmax = 20000, 100
     x = faiss.randint(n, vmax=vmax)
     assert np.all(x >= 0) and np.all(x < vmax)
     counts = np.bincount(x, minlength=vmax)
     # The expected count per bucket is n / vmax = 200; a max-min spread below
     # 100 is a loose sanity bound, not a statistical test. The histogram goes
     # into the failure message instead of being printed on every run.
     spread = counts.max() - counts.min()
     assert spread < 50 * 2, (
         "bucket counts too uneven (spread=%d): %s" % (spread, counts))