def test_clipping(self):
    """ verify that a clipped residual quantizer gives the same
    code prefix + suffix as the full RQ """
    ds = datasets.SyntheticDataset(32, 1000, 100, 0)
    xb = ds.get_database()

    # full quantizer: 5 codebooks of 4 bits each
    rq = faiss.ResidualQuantizer(ds.d, 5, 4)
    rq.train_type = faiss.ResidualQuantizer.Train_default
    rq.max_beam_size = 5
    rq.train(ds.get_train())
    # encode with beam size 1: the prefix + suffix property checked below
    # is not the same for a large beam size
    rq.max_beam_size = 1
    codes = rq.compute_codes(xb)

    # rq2: 2-codebook quantizer initialized from rq; its codes are expected
    # to match the beginning of rq's codes
    rq2 = faiss.ResidualQuantizer(ds.d, 2, 4)
    rq2.initialize_from(rq)
    self.assertEqual(rq2.M, 2)
    codes2 = rq2.compute_codes(xb)

    # rq3: 3-codebook quantizer initialized from rq starting at codebook 2,
    # applied to the residual left after rq2's reconstruction
    rq3 = faiss.ResidualQuantizer(ds.d, 3, 4)
    rq3.initialize_from(rq, 2)
    self.assertEqual(rq3.M, 3)
    codes3 = rq3.compute_codes(xb - rq2.decode(codes2))

    # verify that each full code is the rq2 code (prefix) followed by the
    # rq3 code (suffix); br is read sequentially, so its second read starts
    # right after the first rq2.tot_bits bits
    for i in range(ds.nb):
        br = faiss.BitstringReader(faiss.swig_ptr(codes[i]), rq.code_size)
        br2 = faiss.BitstringReader(faiss.swig_ptr(codes2[i]), rq2.code_size)
        self.assertEqual(br.read(rq2.tot_bits), br2.read(rq2.tot_bits))
        br3 = faiss.BitstringReader(faiss.swig_ptr(codes3[i]), rq3.code_size)
        self.assertEqual(br.read(rq3.tot_bits), br3.read(rq3.tot_bits))
def unpack_codes(rq, packed_codes):
    """Split bit-packed quantizer codes into one integer per field.

    Args:
        rq: quantizer whose ``nbits`` vector gives the bit width of each
            field of a packed code (read with ``faiss.vector_to_array``).
        packed_codes: uint8 array of shape (n, code_size), one packed code
            per row. Rows need not be contiguous in memory.

    Returns:
        uint32 array of shape (n, len(rq.nbits)); column j holds the value
        of field j, which is nbits[j] bits wide.

    Raises:
        ValueError: if a row holds fewer bits than ``sum(nbits)``.
    """
    nbits = faiss.vector_to_array(rq.nbits)
    if np.all(nbits == 8):
        # each field occupies exactly one byte: nothing to unpack
        return packed_codes.astype("uint32")
    nbits = [int(x) for x in nbits]
    n = packed_codes.shape[0]

    # Expand every byte into its bits, least-significant bit first. This is
    # the bit order faiss.BitstringReader consumes, so bit k of a row's
    # bitstream is bits[:, k].
    bits = np.unpackbits(packed_codes, axis=1, bitorder="little")
    total_bits = sum(nbits)
    if bits.shape[1] < total_bits:
        raise ValueError(
            "packed codes have %d bits per row, need %d"
            % (bits.shape[1], total_bits))

    codes = np.zeros((n, len(nbits)), dtype="uint32")
    offset = 0
    for j, nbi in enumerate(nbits):
        # weight bit t of the field by 2**t to rebuild its integer value
        weights = np.left_shift(np.uint64(1), np.arange(nbi, dtype="uint64"))
        field_bits = bits[:, offset:offset + nbi].astype("uint64")
        codes[:, j] = field_bits @ weights
        offset += nbi
    return codes
def test_rw(self):
    """Write random-width bit fields with BitstringWriter, check the packed
    bytes against a Python big-int reference, then read them back with
    BitstringReader."""
    rs = np.random.RandomState(1234)
    nbyte = 1000

    # The buffer starts as all ones. The byte check below expects every bit
    # the writer did not set to be 0, so this also exercises the writer's
    # handling of pre-existing buffer contents.
    bs = np.ones(nbyte, dtype='uint8')
    bw = faiss.BitstringWriter(swig_ptr(bs), nbyte)

    # Write fields of 1..62 bits until the next one would not fit.
    # rand() ** 4 skews the widths toward small values.
    ctrl = []
    sz = 0
    while True:
        nbit = int(1 + 62 * rs.rand() ** 4)
        if sz + nbit > nbyte * 8:
            break
        # plain Python int, so the big-int shifts below never overflow
        x = int(rs.randint(1 << nbit))
        bw.write(x, nbit)
        ctrl.append((nbit, x))
        sz += nbit

    # Reference: concatenate the fields LSB-first into one Python int. Byte i
    # of the buffer must equal bits [8 * i, 8 * i + 8) of that int.
    bignum = 0
    offset = 0
    for nbit, x in ctrl:
        bignum |= x << offset
        offset += nbit
    self.assertEqual(bignum.to_bytes(nbyte, 'little'), bs.tobytes())

    # read the fields back in the order they were written
    br = faiss.BitstringReader(swig_ptr(bs), nbyte)
    for nbit, xref in ctrl:
        self.assertEqual(br.read(nbit), xref)