Ejemplo n.º 1
0
    def test_clipping(self):
        """Verify that clipped residual quantizers reproduce the full RQ codes.

        A 5-stage RQ (4 bits per stage) is trained, then two sub-quantizers
        are initialized from it: ``rq2`` (2 stages, ``initialize_from(rq)``)
        for the code prefix and ``rq3`` (3 stages, ``initialize_from(rq, 2)``)
        for the suffix. Encoding the database with ``rq2`` must give the same
        leading bits as the full code, and encoding the ``rq2`` residual with
        ``rq3`` must give the bits that follow.
        """
        ds = datasets.SyntheticDataset(32, 1000, 100, 0)
        xb = ds.get_database()

        rq = faiss.ResidualQuantizer(ds.d, 5, 4)
        rq.train_type = faiss.ResidualQuantizer.Train_default
        rq.max_beam_size = 5
        rq.train(ds.get_train())

        # Encode with beam size 1: with a large beam the codes are not the
        # same (presumably the beam search can pick a different prefix than
        # the clipped quantizer would).
        rq.max_beam_size = 1
        codes = rq.compute_codes(xb)

        # Prefix quantizer built from rq.
        rq2 = faiss.ResidualQuantizer(ds.d, 2, 4)
        rq2.initialize_from(rq)
        self.assertEqual(rq2.M, 2)
        codes2 = rq2.compute_codes(xb)

        # Suffix quantizer built from rq (presumably starting at stage 2),
        # applied to what rq2 leaves unexplained.
        rq3 = faiss.ResidualQuantizer(ds.d, 3, 4)
        rq3.initialize_from(rq, 2)
        self.assertEqual(rq3.M, 3)
        codes3 = rq3.compute_codes(xb - rq2.decode(codes2))

        # Each full code must start with the rq2 bits, followed by the rq3 bits.
        for i in range(ds.nb):
            br = faiss.BitstringReader(faiss.swig_ptr(codes[i]), rq.code_size)
            br2 = faiss.BitstringReader(faiss.swig_ptr(codes2[i]),
                                        rq2.code_size)
            self.assertEqual(br.read(rq2.tot_bits), br2.read(rq2.tot_bits))
            # br is not reset: this read continues right after the prefix bits.
            br3 = faiss.BitstringReader(faiss.swig_ptr(codes3[i]),
                                        rq3.code_size)
            self.assertEqual(br.read(rq3.tot_bits), br3.read(rq3.tot_bits))
Ejemplo n.º 2
0
def unpack_codes(rq, packed_codes):
    """Split bit-packed quantizer codes into one integer per codebook.

    Args:
        rq: quantizer whose ``nbits`` vector gives the bit width of each
            codebook (converted with ``faiss.vector_to_array``).
        packed_codes: 2-D array of shape ``(n, code_size)``; presumably uint8
            bytes as produced by ``compute_codes`` -- confirm with callers.

    Returns:
        A uint32 array with one row per packed code. When every codebook is
        8 bits wide, the bytes are simply cast. Otherwise each row is decoded
        field by field into ``len(rq.nbits)`` columns.
    """
    widths = faiss.vector_to_array(rq.nbits)
    if (widths == 8).all():
        # Byte-aligned layout: no bit-level unpacking needed.
        return packed_codes.astype("uint32")

    widths = list(map(int, widths))
    n_rows, code_size = packed_codes.shape
    unpacked = np.zeros((n_rows, len(widths)), dtype="uint32")
    for row_out, row_in in zip(unpacked, packed_codes):
        # Fields are consumed in order; each read advances the reader.
        reader = faiss.BitstringReader(faiss.swig_ptr(row_in), code_size)
        for col, width in enumerate(widths):
            row_out[col] = reader.read(width)
    return unpacked
Ejemplo n.º 3
0
    def test_rw(self):
        """Round-trip random variable-width fields through BitstringWriter/Reader.

        Random fields of 1..62 bits are written into an ``nbyte``-byte buffer
        until the next field would not fit. The buffer is then compared byte
        by byte against a Python big integer built from the same fields
        (first field in the lowest bits). Finally the fields are read back in
        order and compared with the values written.
        """
        rs = np.random.RandomState(1234)
        nbyte = 1000

        # The buffer starts as all ones. The byte comparison below expects
        # zero wherever no field bit was written, so this also checks that
        # the writer clears the buffer.
        bs = np.ones(nbyte, dtype='uint8')
        bw = faiss.BitstringWriter(swig_ptr(bs), nbyte)

        ctrl = []  # (nbit, value) pairs in write order
        written_bits = 0
        while True:
            # rand()**4 skews the widths towards small values.
            nbit = int(1 + 62 * rs.rand()**4)
            if written_bits + nbit > nbyte * 8:
                break
            x = rs.randint(1 << nbit)
            bw.write(x, nbit)
            ctrl.append((nbit, x))
            written_bits += nbit

        # Reference encoding: concatenate the fields into one big integer.
        bignum = 0
        shift = 0
        for nbit, x in ctrl:
            bignum |= x << shift
            shift += nbit

        for i in range(nbyte):
            self.assertEqual((bignum >> (i * 8)) & 255, bs[i])

        br = faiss.BitstringReader(swig_ptr(bs), nbyte)
        for nbit, xref in ctrl:
            self.assertEqual(br.read(nbit), xref)