def test_hash_and_multihash(self):
    d = 128
    nq = 100
    nb = 2000

    (_, xb, xq) = make_binary_dataset(d, 0, nb, nq)

    index_ref = faiss.IndexBinaryFlat(d)
    index_ref.add(xb)
    k = 10
    Dref, Iref = index_ref.search(xq, k)

    nfound = {}
    for nh in 0, 1, 3, 5:
        for nbit in 4, 7:
            if nh == 0:
                index = faiss.IndexBinaryHash(d, nbit)
            else:
                index = faiss.IndexBinaryMultiHash(d, nh, nbit)
            index.add(xb)
            index.nflip = 2
            Dnew, Inew = index.search(xq, k)
            nf = 0
            for i in range(nq):
                ref = Iref[i]
                new = Inew[i]
                snew = set(new)
                # no duplicates
                self.assertTrue(len(new) == len(snew))
                nf += len(set(ref) & snew)
            print('nfound', nh, nbit, nf)
            nfound[(nh, nbit)] = nf
        self.assertGreater(nfound[(nh, 4)], nfound[(nh, 7)])

        # test serialization
        index2 = faiss.deserialize_index_binary(
            faiss.serialize_index_binary(index))
        D2, I2 = index2.search(xq, k)
        np.testing.assert_array_equal(Inew, I2)
        np.testing.assert_array_equal(Dnew, D2)

    print('nfound=', nfound)
    self.assertGreater(3, abs(nfound[(0, 7)] - nfound[(1, 7)]))
    self.assertGreater(nfound[(3, 7)], nfound[(1, 7)])
    self.assertGreater(nfound[(5, 7)], nfound[(3, 7)])
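# A minimal self-contained sketch (not part of the test above) of the same
# IndexBinaryHash / IndexBinaryMultiHash usage on random packed-bit data.
# Vectors are d bits stored as d // 8 uint8 bytes; nflip sets how many bits
# of the hash key may be flipped when probing buckets at search time
# (higher nflip = better recall, more distance computations).
import faiss
import numpy as np

d = 128  # dimension in bits
rs = np.random.RandomState(0)
xb = rs.randint(256, size=(2000, d // 8)).astype('uint8')
xq = rs.randint(256, size=(10, d // 8)).astype('uint8')

index = faiss.IndexBinaryMultiHash(d, 3, 7)  # 3 hash tables, 7 bits each
index.add(xb)
index.nflip = 2
D, I = index.search(xq, 10)  # D: Hamming distances, I: database ids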
def subtest_result_order(self, nh):
    d = 128
    nq = 10
    nb = 200

    (_, xb, xq) = make_binary_dataset(d, 0, nb, nq)

    nbit = 10
    if nh == 0:
        index = faiss.IndexBinaryHash(d, nbit)
    else:
        index = faiss.IndexBinaryMultiHash(d, nh, nbit)
    index.add(xb)
    index.nflip = 5
    k = 10
    Do, Io = index.search(xq, k)
    # distances must come back in non-decreasing order per query
    self.assertTrue(np.all(Do[:, 1:] >= Do[:, :-1]))
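# Hedged sketch: drivers that would exercise subtest_result_order for both
# index types (the method names and nh values here are assumptions, not
# part of the original snippet).
def test_result_order(self):
    self.subtest_result_order(0)  # IndexBinaryHash

def test_result_order_multihash(self):
    self.subtest_result_order(2)  # IndexBinaryMultiHash with 2 hash tables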
def test_hash(self):
    d = 128
    nq = 100
    nb = 2000

    (_, xb, xq) = make_binary_dataset(d, 0, nb, nq)

    index_ref = faiss.IndexBinaryFlat(d)
    index_ref.add(xb)

    radius = 55

    Lref, Dref, Iref = index_ref.range_search(xq, radius)

    print("nb res: ", Lref[-1])

    index = faiss.IndexBinaryHash(d, 10)
    index.add(xb)
    # index.display()
    nfound = []
    ndis = []

    stats = faiss.cvar.indexBinaryHash_stats
    for n_bitflips in range(index.b + 1):
        index.nflip = n_bitflips
        stats.reset()
        Lnew, Dnew, Inew = index.range_search(xq, radius)
        for i in range(nq):
            ref = Iref[Lref[i]:Lref[i + 1]]
            new = Inew[Lnew[i]:Lnew[i + 1]]
            snew = set(new)
            # no duplicates
            self.assertTrue(len(new) == len(snew))
            # subset of real results
            self.assertTrue(snew <= set(ref))
        nfound.append(Lnew[-1])
        ndis.append(stats.ndis)
    print('nfound=', nfound)
    print('ndis=', ndis)
    nfound = np.array(nfound)
    self.assertTrue(nfound[-1] == Lref[-1])
    self.assertTrue(np.all(nfound[1:] >= nfound[:-1]))
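# The tests above depend on a make_binary_dataset helper that is not shown
# here. A minimal sketch consistent with how it is called: it returns
# (train, base, query) arrays of packed bits, d // 8 uint8 bytes per vector.
# The exact helper in the faiss test suite may differ.
import numpy as np

def make_binary_dataset(d, nt, nb, nq):
    assert d % 8 == 0
    rs = np.random.RandomState(123)
    x = rs.randint(256, size=(nt + nb + nq, d // 8)).astype('uint8')
    return x[:nt], x[nt:nt + nb], x[nt + nb:]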
def registry_index(way_index):
    # assert way_index in range(len(DIMENSIONS))
    # prepare the index
    dimensions = DIMENSIONS[way_index]
    if isAddPhash:
        dimensions += PHASH_X * PHASH_Y
    # https://github.com/facebookresearch/faiss/wiki/Binary-indexes
    # https://github.com/facebookresearch/faiss/blob/22b7876ef5540b85feee173aa3182a2f37dc98f6/tests/test_index_binary.py#L213
    if way_index != 3:
        # IndexBinaryHash takes the dimension in bits, hence dimensions * 8
        # https://github.com/facebookresearch/faiss/wiki/Faiss-indexes#relationship-with-lsh
        index = faiss.IndexBinaryHash(dimensions * 8, 1)
    else:
        index = faiss.index_factory(dimensions, INDEX_KEY)
    if USE_GPU:
        print("Use GPU...")
        res = faiss.StandardGpuResources()
        index = faiss.index_cpu_to_gpu(res, 0, index)

    # gather the images to register
    images_list = iterate_files(train_image_dir)

    # prepare ids
    ids_count = 0
    index_defaultdict = defaultdict(list)
    features = []
    ids = []
    cla_name_temp = parser_name(images_list[0])
    way = get_way(w_index=way_index)  # ORB, SURF, and so on

    for file_name in images_list:
        cla_name = parser_name(file_name)
        ret, feature = way_feature(way, file_name)
        numf = feature.shape[0]
        if way_index == 3 and FEATURE_CLIP:
            # keep at most FEATURE_CLIP randomly sampled features per image
            numf = min(feature.shape[0], FEATURE_CLIP)
            chosen = sample(range(feature.shape[0]), numf)
            feature = feature[chosen, :]
        if ret == 0 and feature.any():
            if cla_name != cla_name_temp:
                # start a new id when the class name changes
                ids_count += 1
                cla_name_temp = cla_name
            # one id may map to more than one image (object)
            index_defaultdict[ids_count].append(file_name)
            ids_list = np.linspace(ids_count, ids_count, num=numf, dtype="int64")
            print(feature.shape, ids_count, len(ids_list), ids_list.shape)
            features.append(feature)
            ids.append(ids_list)

    # add everything in one batch
    # https://github.com/facebookresearch/faiss/issues/856
    features = np.vstack(features)
    ids = np.hstack(ids)
    print(features.shape, ids.shape)
    if features.any():
        if not index.is_trained:
            index.train(features)
        index.add_with_ids(features, ids)

    # save the index (way_index 3 is a float index, the others are binary)
    if way_index == 3:
        faiss.write_index(index, index_path)
    else:
        faiss.write_index_binary(index, index_path)

    # save the ids
    if not os.path.exists(ids_path):
        with open(ids_path, 'wb+') as f:
            try:
                pickle.dump(index_defaultdict, f, True)
            except EnvironmentError as e:
                logging.error('Failed to save index file error:[{}]'.format(e))
            except RuntimeError as v:
                logging.error('Failed to save index file error:[{}]'.format(v))
    print('Registry completed')
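# Hedged sketch of the matching lookup side (the function name and `k` are
# assumptions, not from the original): load the binary index written by
# registry_index together with the pickled id -> file mapping, then search.
# This covers the binary branches; the way_index == 3 float index would be
# loaded with faiss.read_index instead.
def query_index(features, k=10):
    index = faiss.read_index_binary(index_path)
    with open(ids_path, 'rb') as f:
        index_defaultdict = pickle.load(f)
    # features: packed binary vectors, shape (n, d_bits // 8), dtype uint8
    D, I = index.search(features, k)
    return [[index_defaultdict[int(i)] for i in row] for row in I]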
passage_db = PassageDB(args.passage_db_file)
embedding_data = joblib.load(args.embedding_file, mmap_mode="r")
ids, embeddings = embedding_data["ids"], embedding_data["embeddings"]
dim_size = embeddings.shape[1]

logger.info("Building index...")
if embeddings.dtype == np.uint8:
    if args.binary_to_float:
        # unpack the packed bits to {0, 1}, then map them to {-1, +1} floats
        embeddings = np.unpackbits(embeddings).reshape(
            -1, dim_size * 8).astype(np.float32)
        embeddings = embeddings * 2 - 1
        base_index = faiss.IndexFlatIP(dim_size * 8)
        index = FaissIndex.build(ids, embeddings, base_index)
    elif args.use_binary_hash:
        base_index = faiss.IndexBinaryHash(dim_size * 8, args.hash_num_bits)
        index = FaissBinaryIndex.build(ids, embeddings, base_index)
    else:
        base_index = faiss.IndexBinaryFlat(dim_size * 8)
        index = FaissBinaryIndex.build(ids, embeddings, base_index)
elif args.use_hnsw:
    base_index = faiss.IndexHNSWFlat(dim_size + 1, args.hnsw_store_n)
    base_index.hnsw.efSearch = args.hnsw_ef_search
    base_index.hnsw.efConstruction = args.hnsw_ef_construction
    index = FaissHNSWIndex.build(ids, embeddings, base_index)
else:
    base_index = faiss.IndexFlatIP(dim_size)
    index = FaissIndex.build(ids, embeddings, base_index)
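# Why "embeddings * 2 - 1" in the binary_to_float branch above: unpacking
# the packed bits to {0, 1} and mapping them to {-1, +1} makes the inner
# product used by IndexFlatIP a monotone function of Hamming distance
# (ip = n_bits - 2 * hamming). A small sanity check of that identity:
import numpy as np

a = np.random.randint(2, size=64)
b = np.random.randint(2, size=64)
ip = ((a * 2 - 1) * (b * 2 - 1)).sum()
assert ip == 64 - 2 * (a != b).sum()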