def test_shards(self):
    """Check that sharded indexes agree with a flat L2 reference index.

    Uses the module-level fixtures ``d``, ``nb``, ``nq``, ``xb`` and ``xq``.
    Two ``IndexShards`` objects are built over three sub-indexes each:

    * ``shard_index``: every shard is pre-filled with one third of ``xb``.
    * ``shard_index_2``: built with the extra constructor flags
      ``(True, False)`` and filled through its own ``add`` call
      (NOTE(review): flags are presumably ``threaded`` / ``successive_ids``
      -- confirm against the Faiss API).

    Three searches are run: ``shard_index`` non-threaded (test_no 0),
    ``shard_index`` with ``threaded = True`` while OpenMP is limited to a
    single thread (test_no 1), and ``shard_index_2`` (test_no 2). Each pass
    succeeds when fewer than 0.1% of the returned ids differ from the
    reference result.
    """
    k = 32
    ref_index = faiss.IndexFlatL2(d)

    print('ref search')
    ref_index.add(xb)
    _Dref, Iref = ref_index.search(xq, k)
    print(Iref[:5, :6])

    shard_index = faiss.IndexShards(d)
    shard_index_2 = faiss.IndexShards(d, True, False)

    ni = 3
    for i in range(ni):
        # Contiguous slice [i0, i1) of the database handled by shard i.
        i0 = int(i * nb / ni)
        i1 = int((i + 1) * nb / ni)
        index = faiss.IndexFlatL2(d)
        index.add(xb[i0:i1])
        shard_index.add_shard(index)

        # The second sharded index receives empty shards wrapped in an
        # IndexIDMap; they are populated below via shard_index_2.add().
        index_2 = faiss.IndexFlatL2(d)
        irm = faiss.IndexIDMap(index_2)
        shard_index_2.add_shard(irm)

    # test parallel add
    shard_index_2.verbose = True
    shard_index_2.add(xb)

    for test_no in range(3):
        with_threads = test_no == 1

        print('shard search test_no = %d' % test_no)
        if with_threads:
            remember_nt = faiss.omp_get_max_threads()
            faiss.omp_set_num_threads(1)
            shard_index.threaded = True
        else:
            shard_index.threaded = False

        try:
            if test_no != 2:
                _D, I = shard_index.search(xq, k)
            else:
                _D, I = shard_index_2.search(xq, k)

            print(I[:5, :6])
        finally:
            # The OpenMP thread count is process-global: always put it back,
            # even if the search fails, so later tests are not affected.
            if with_threads:
                faiss.omp_set_num_threads(remember_nt)

        ndiff = (I != Iref).sum()

        print('%d / %d differences' % (ndiff, nq * k))
        assert (ndiff < nq * k / 1000.)
def test_replicas(self):
    """Check that IndexBinaryReplicas returns the same results as one flat index.

    A reference ``IndexBinaryFlat`` is searched first. Then two replica
    sets of five sub-indexes are tested: one whose sub-indexes are filled
    individually before being attached, and one whose empty sub-indexes
    are filled through the replica index's own ``add``. Distances and ids
    must match the reference exactly in both cases.
    """
    d = 32
    nq = 100
    nb = 200

    (_, xb, xq) = make_binary_dataset(d, 0, nb, nq)

    index_ref = faiss.IndexBinaryFlat(d)
    index_ref.add(xb)

    Dref, Iref = index_ref.search(xq, 10)

    # there is a OpenMP bug in this configuration, so disable threading
    if sys.platform == "darwin" and "Clang 12" in sys.version:
        nthreads = faiss.omp_get_max_threads()
        faiss.omp_set_num_threads(1)
    else:
        nthreads = None

    try:
        nrep = 5
        index = faiss.IndexBinaryReplicas()
        for _i in range(nrep):
            sub_idx = faiss.IndexBinaryFlat(d)
            sub_idx.add(xb)
            index.addIndex(sub_idx)

        D, I = index.search(xq, 10)

        self.assertTrue((Dref == D).all())
        self.assertTrue((Iref == I).all())

        # Second variant: attach empty replicas, then add through the
        # replica index itself.
        index2 = faiss.IndexBinaryReplicas()
        for _i in range(nrep):
            sub_idx = faiss.IndexBinaryFlat(d)
            index2.addIndex(sub_idx)

        index2.add(xb)
        D2, I2 = index2.search(xq, 10)
    finally:
        # The thread count is global state; restore it even when an
        # assertion or a search above fails.
        if nthreads is not None:
            faiss.omp_set_num_threads(nthreads)

    self.assertTrue((Dref == D2).all())
    self.assertTrue((Iref == I2).all())
def test_hnsw(self):
    """Compare IndexBinaryHNSW distances with a float HNSW wrapped as binary.

    Reads ``self.xb`` (database) and ``self.xq`` (queries). The dimension
    is the number of columns of ``self.xq`` times 8. Both indexes are
    built while OpenMP is limited to one thread, then searched for 3
    neighbours; only the returned distances are compared.
    """
    d = self.xq.shape[1] * 8

    # NOTE(hoss): Ensure the HNSW construction is deterministic.
    nthreads = faiss.omp_get_max_threads()
    faiss.omp_set_num_threads(1)
    try:
        index_hnsw_float = faiss.IndexHNSWFlat(d, 16)
        index_hnsw_ref = faiss.IndexBinaryFromFloat(index_hnsw_float)

        index_hnsw_bin = faiss.IndexBinaryHNSW(d, 16)

        index_hnsw_ref.add(self.xb)
        index_hnsw_bin.add(self.xb)
    finally:
        # Restore the global thread count even if construction fails.
        faiss.omp_set_num_threads(nthreads)

    # The returned ids are not part of the check, hence the underscore names.
    Dref, _Iref = index_hnsw_ref.search(self.xq, 3)
    Dbin, _Ibin = index_hnsw_bin.search(self.xq, 3)

    self.assertTrue((Dref == Dbin).all())
def parse(self):
    """Parse command-line arguments and initialise the distributed run.

    Besides the parsed options, the returned namespace carries:

    * ``input_arguments``: a deep copy of the arguments as parsed.
    * ``nranks`` / ``distributed``: the pair returned by
      ``distributed.init(args.rank)``.
    * ``nthreads``: per-process thread budget, 80% of the larger of the
      CPU count and the OpenMP maximum, divided by ``nranks`` (at least 1).

    When more than one rank is in use, the FAISS OpenMP thread count is set
    to ``nthreads``. ``Run`` is initialised and the arguments are logged
    before returning.

    Returns:
        The populated argument namespace.
    """
    args = self.parser.parse_args()
    self.check_arguments(args)

    args.input_arguments = copy.deepcopy(args)

    args.nranks, args.distributed = distributed.init(args.rank)

    # os.cpu_count() returns None when the count cannot be determined;
    # fall back to 1 so that max() does not compare None with an int.
    cpu_count = os.cpu_count() or 1
    args.nthreads = int(max(cpu_count, faiss.omp_get_max_threads()) * 0.8)
    args.nthreads = max(1, args.nthreads // args.nranks)

    if args.nranks > 1:
        print_message(
            f"#> Restricting number of threads for FAISS to {args.nthreads} per process",
            condition=(args.rank == 0))
        faiss.omp_set_num_threads(args.nthreads)

    Run.init(args.rank, args.root, args.experiment, args.run)
    Run._log_args(args)
    Run.info(args.input_arguments.__dict__, '\n')

    return args
# find operating points for this index
opi = params.explore(index, xq, crit)

# print() with a single %-formatted argument is valid on Python 2 and 3,
# unlike the print statement.
print("[%.3f s] result operating points:" % (time.time() - t0))
opi.display()

# update best operating points so far
op.merge_with(opi, index_key + " ")

op_per_key.append((index_key, opi))

if True:
    # graphical output (to tmp/ subdirectory)
    fig = pyplot.figure(figsize=(12, 9))
    pyplot.xlabel("1-recall at 1")
    pyplot.ylabel("search time (ms/query, %d threads)" % faiss.omp_get_max_threads())
    pyplot.gca().set_yscale('log')
    pyplot.grid()
    # one curve per index key explored so far
    for i2, opi2 in op_per_key:
        plot_OperatingPoints(opi2, crit.nq, label=i2, marker='o')
    # plot_OperatingPoints(op, crit.nq, label = 'best', marker = 'o', color = 'r')
    pyplot.legend(loc=2)
    fig.savefig('tmp/demo_auto_tune.png')

print("[%.3f s] final result:" % (time.time() - t0))

op.display()
# find operating points for this index
opi = params.explore(index, xq, crit)

# Parenthesised print works as a function call on Python 3 and still
# prints the same single string on Python 2.
print("[%.3f s] result operating points:" % (time.time() - t0))
opi.display()

# update best operating points so far
op.merge_with(opi, index_key + " ")

op_per_key.append((index_key, opi))

if True:
    # graphical output (to tmp/ subdirectory)
    fig = pyplot.figure(figsize=(12, 9))
    pyplot.xlabel("1-recall at 1")
    pyplot.ylabel("search time (ms/query, %d threads)" % faiss.omp_get_max_threads())
    pyplot.gca().set_yscale('log')
    pyplot.grid()
    # draw the operating points collected for every key so far
    for i2, opi2 in op_per_key:
        plot_OperatingPoints(opi2, crit.nq, label=i2, marker='o')
    # plot_OperatingPoints(op, crit.nq, label = 'best', marker = 'o', color = 'r')
    pyplot.legend(loc=2)
    fig.savefig('tmp/demo_auto_tune.png')

print("[%.3f s] final result:" % (time.time() - t0))

op.display()
print("precomputed tables size:", precomputed_table_size) ############################################################# # Index is ready ############################################################# xq = sanitize(ds.get_queries()) gt = ds.get_groundtruth(k=args.k) assert gt.shape[1] == args.k, pdb.set_trace() if args.searchthreads != -1: print("Setting nb of threads to", args.searchthreads) faiss.omp_set_num_threads(args.searchthreads) else: print("nb search threads: ", faiss.omp_get_max_threads()) ps = faiss.ParameterSpace() ps.initialize(index) parametersets = args.searchparams if args.inter: header = ( '%-40s inter@%3d time(ms/q) nb distances #runs' % ("parameters", args.k) ) else:
# find operating points for this index
opi = params.explore(index, xq, crit)

# The print statement is a syntax error on Python 3; the call form below
# behaves identically on both major versions.
print("[%.3f s] result operating points:" % (time.time() - t0))
opi.display()

# update best operating points so far
op.merge_with(opi, index_key + " ")

op_per_key.append((index_key, opi))

if True:
    # graphical output (to tmp/ subdirectory)
    fig = pyplot.figure(figsize=(12, 9))
    pyplot.xlabel("1-recall at 1")
    pyplot.ylabel("search time (ms/query, %d threads)" % faiss.omp_get_max_threads())
    pyplot.gca().set_yscale('log')
    pyplot.grid()
    # plot each explored key's operating points with its key as the label
    for i2, opi2 in op_per_key:
        plot_OperatingPoints(opi2, crit.nq, label=i2, marker='o')
    # plot_OperatingPoints(op, crit.nq, label = 'best', marker = 'o', color = 'r')
    pyplot.legend(loc=2)
    fig.savefig('tmp/demo_auto_tune.png')

print("[%.3f s] final result:" % (time.time() - t0))

op.display()
def test_shards(self):
    """Check that sharded indexes agree with a flat L2 reference index.

    Uses the module-level fixtures ``d``, ``nb``, ``nq``, ``xb`` and ``xq``.
    On macOS with Clang 12 the whole sharded part runs with OpenMP limited
    to one thread (workaround for an OpenMP bug in that configuration).

    Three searches are compared with the reference: ``shard_index``
    non-threaded, ``shard_index`` with ``threaded = True`` under a single
    OpenMP thread, and ``shard_index_2``, whose shards were filled through
    the sharded index's own ``add``. Each pass succeeds when fewer than
    0.1% of the returned ids differ from the reference result.
    """
    k = 32
    ref_index = faiss.IndexFlatL2(d)

    print('ref search')
    ref_index.add(xb)
    _Dref, Iref = ref_index.search(xq, k)
    print(Iref[:5, :6])

    # there is a OpenMP bug in this configuration, so disable threading
    if sys.platform == "darwin" and "Clang 12" in sys.version:
        nthreads = faiss.omp_get_max_threads()
        faiss.omp_set_num_threads(1)
    else:
        nthreads = None

    try:
        shard_index = faiss.IndexShards(d)
        shard_index_2 = faiss.IndexShards(d, True, False)

        ni = 3
        for i in range(ni):
            # Contiguous slice [i0, i1) of the database handled by shard i.
            i0 = int(i * nb / ni)
            i1 = int((i + 1) * nb / ni)
            index = faiss.IndexFlatL2(d)
            index.add(xb[i0:i1])
            shard_index.add_shard(index)

            # Empty shards wrapped in an IndexIDMap; populated below
            # through shard_index_2.add().
            index_2 = faiss.IndexFlatL2(d)
            irm = faiss.IndexIDMap(index_2)
            shard_index_2.add_shard(irm)

        # test parallel add
        shard_index_2.verbose = True
        shard_index_2.add(xb)

        for test_no in range(3):
            with_threads = test_no == 1

            print('shard search test_no = %d' % test_no)
            if with_threads:
                remember_nt = faiss.omp_get_max_threads()
                faiss.omp_set_num_threads(1)
                shard_index.threaded = True
            else:
                shard_index.threaded = False

            try:
                if test_no != 2:
                    _D, I = shard_index.search(xq, k)
                else:
                    _D, I = shard_index_2.search(xq, k)

                print(I[:5, :6])
            finally:
                # Undo the per-pass override even when the search raises.
                if with_threads:
                    faiss.omp_set_num_threads(remember_nt)

            ndiff = (I != Iref).sum()

            print('%d / %d differences' % (ndiff, nq * k))
            assert (ndiff < nq * k / 1000.)
    finally:
        # The thread count is process-global; restore it even when an
        # assertion above fails so other tests keep their threading.
        if nthreads is not None:
            faiss.omp_set_num_threads(nthreads)
def generate_new_ann(args, output_num, checkpoint_path, training_query_positive_id, dev_query_positive_id,
                     query_embcache, passage_embcache):
    """Embed passages and queries, search a faiss index and write ANN training data.

    Steps performed here:

    1. Load the model from ``checkpoint_path``, retrying every 60 seconds
       until loading succeeds.
    2. Run passage inference with ``end_batch`` set to
       ``args.corpus_divider`` (clamped to [0, 1]) times the number of
       batches per GPU.
    3. Build a faiss index over the passage embeddings: ``IndexFlatIP``
       when ``args.flat_index`` is set, otherwise ``"IVF8192,Flat"`` plus a
       separate flat index used only for the debug output of batch 0.
    4. For every query batch (up to ``end_idx``): embed the queries, keep
       those present in ``training_query_positive_id``, combine results
       through ``all_gather`` and, on the first worker, write one
       tab-separated line ``qid, positive pid, comma-joined negative
       pids`` to ``ann_training_data_<output_num>``. For the first 100
       queries of batch 0 the decoded texts go to ``ann_debug_<output_num>``.
    5. Write ``ann_ndcg_<output_num>`` as JSON.

    Args:
        args: run configuration; ``args.corpus_divider`` is overwritten
            with its clamped value.
        output_num: suffix of the output files, also used as cache seed.
        checkpoint_path: checkpoint to load the model from.
        training_query_positive_id: mapping from query id to its positive
            passage id.
        dev_query_positive_id: not used by this function.
        query_embcache: query cache (needs ``change_seed`` and ``len``).
        passage_embcache: passage cache (needs ``change_seed`` and ``len``).

    Returns:
        ``(dev_ndcg, num_queries_dev)``. Neither value is updated in this
        function, so this is currently always ``(0.0, 0)``.
    """
    # Retry until the checkpoint loads. Only Exception is caught so that
    # KeyboardInterrupt / SystemExit can still stop the process.
    while True:
        try:
            _config, tokenizer, model = load_model(args, checkpoint_path)
            break
        except Exception as err:
            logger.warning(f"loading model from {checkpoint_path} failed: {err!r}")
            time.sleep(60)
            print("retry loading model")

    passage_embcache.change_seed(output_num)
    query_embcache.change_seed(output_num)
    model.eval()

    logger.info("***** inference of passages *****")
    num_batches_per_gpu = len(passage_embcache) // (args.per_gpu_eval_batch_size * dist.get_world_size())
    # Clamp to [0, 1] with the builtins: the math module has no max/min.
    args.corpus_divider = max(min(args.corpus_divider, 1.0), 0.0)
    # only run embedding inference for a `corpus_divider` fraction of the
    # passage batches to speed up the process
    passage_embedding, passage_embedding2id = StreamInferenceDoc(
        args, model, GetProcessingFn(args, query=False), "passage_", passage_embcache,
        is_query_inference=False, end_batch=num_batches_per_gpu * args.corpus_divider)
    logger.info("***** Done passage inference *****")

    # build index partition on each process
    dim = passage_embedding.shape[1]
    logger.info('passage embedding shape: ' + str(passage_embedding.shape))
    top_k = args.topk_training
    # Split a budget of 32 threads between the processes; never request 0.
    faiss.omp_set_num_threads(max(1, 32 // dist.get_world_size()))
    print(faiss.omp_get_max_threads())
    if args.flat_index:
        cpu_index = faiss.IndexFlatIP(dim)
    else:
        cpu_index = faiss.index_factory(dim, "IVF8192,Flat")
        cpu_index.train(passage_embedding)
    cpu_index.add(passage_embedding)
    logger.info("***** Done training Index *****")
    cpu_index.nprobe = 50
    logger.info("***** Done building ANN Index *****")
    dist.barrier()

    # Flat index searched only for batch 0 (debug output); reuse the main
    # index when it is already flat.
    if args.flat_index:
        flat_index = cpu_index
    else:
        flat_index = faiss.IndexFlatIP(dim)
        flat_index.add(passage_embedding)
    logger.info("**** Done building flat index *****")

    dev_ndcg, num_queries_dev = 0.0, 0
    train_data_output_path = os.path.join(args.output_dir, "ann_training_data_" + str(output_num))
    debug_output_path = os.path.join(args.output_dir, "ann_debug_" + str(output_num))

    with open(train_data_output_path, 'w') as f, \
            open(debug_output_path, "w", encoding="utf-8") as debug_g:
        chunk_factor = args.ann_chunk_factor
        if chunk_factor <= 0:
            chunk_factor = 1
        num_batches_per_gpu = len(query_embcache) // (args.per_gpu_eval_batch_size * dist.get_world_size())
        end_idx = num_batches_per_gpu // chunk_factor
        print("End idx:", end_idx)

        inference_dataset = StreamingDataset(query_embcache, GetProcessingFn(args, query=True))
        inference_dataloader = DataLoader(inference_dataset, batch_size=args.per_gpu_eval_batch_size)

        for m_batch, batch in tqdm(enumerate(inference_dataloader), desc="Inferencing",
                                   disable=args.local_rank not in [-1, 0], position=0):
            if m_batch > end_idx:
                break

            qids = batch[3].detach().numpy()  # [#B]
            batch = tuple(t.to(args.device) for t in batch)
            with torch.no_grad():
                inputs = {"input_ids": batch[0].long(), "attention_mask": batch[1].long()}
                embs = model.module.query_emb(**inputs)
                embs = embs.detach().cpu().numpy()

            # take only queries with positive passage, then collect
            pos_idx = [i for i, qid in enumerate(qids) if qid in training_query_positive_id]
            query_embedding = embs[pos_idx]
            q_chunk_qid = qids[pos_idx]
            tmp_obj = {"emb": query_embedding, "id": q_chunk_qid}
            objs = all_gather(tmp_obj)
            query_embedding = concat_key(objs, "emb", axis=0)
            q_chunk_qid = concat_key(objs, "id", axis=0)
            if m_batch == 0:
                print(query_embedding.shape, q_chunk_qid.shape)

            D, I = cpu_index.search(query_embedding, top_k)
            # map index rows back to passage ids
            I = passage_embedding2id[I]
            if m_batch % 100 == 0:
                logger.info(f"***** Done querying ANN Index chunk {m_batch}*****")

            knn_pkl = {"D": D, "I": I}
            all_knn_list = all_gather(knn_pkl)
            del knn_pkl

            if m_batch == 0:
                # we only do flat_index search to debug
                Df, If = flat_index.search(query_embedding, top_k)
                If = passage_embedding2id[If]
                knn_pkl_flat = {"D": Df, "I": If}
                all_knn_list_flat = all_gather(knn_pkl_flat)
                del knn_pkl_flat

            if is_first_worker():
                D_merged = concat_key(all_knn_list, "D", axis=1)
                I_merged = concat_key(all_knn_list, "I", axis=1)

                if not args.ann_measure_topk_mrr:
                    # random subset of candidate columns, shared by all rows
                    shuffled_ix = np.random.permutation(I_merged.shape[1])[:args.negative_sample + 1]
                    sub_I = np.take(I_merged, shuffled_ix, axis=1)
                else:
                    # Keep negative_sample + 1 columns per row; the slice
                    # must be on axis 1 for the shape checks below to hold.
                    # NOTE(review): argsort is ascending; whether the first
                    # or the last columns are the "top" hits depends on the
                    # index metric -- confirm the intended order.
                    top_idx = np.argsort(D_merged, axis=1)[:, :args.negative_sample + 1]
                    sub_I = np.take_along_axis(I_merged, top_idx, axis=1)

                assert sub_I.shape[0] == len(q_chunk_qid)
                assert sub_I.shape[1] == args.negative_sample + 1

                if m_batch == 0:
                    D_merged_flat = concat_key(all_knn_list_flat, "D", axis=1)
                    I_merged_flat = concat_key(all_knn_list_flat, "I", axis=1)
                    shuffled_ix_flat = np.random.permutation(I_merged_flat.shape[1])[:args.negative_sample + 1]
                    sub_I_flat = np.take(I_merged_flat, shuffled_ix_flat, axis=1)
                else:
                    # placeholder so that zip() below still yields one item per row
                    sub_I_flat = [0] * len(sub_I)

                for i, (qid, row, row_flat) in enumerate(zip(q_chunk_qid, sub_I, sub_I_flat)):
                    pos_pid = training_query_positive_id[qid]
                    # drop the positive passage from the candidates
                    neg_pids = [x for x in row if x != pos_pid][:args.negative_sample]
                    if m_batch == 0:
                        neg_pids_flat = [x for x in row_flat if x != pos_pid][:args.negative_sample]
                    f.write("{}\t{}\t{}\n".format(qid, pos_pid, ','.join(str(nid) for nid in neg_pids)))

                    # console might crash if encoding isnt set correctly.
                    # run export PYTHONIOENCODING=UTF-8 before running the script if that happens.
                    if m_batch == 0 and i < 100:
                        q1 = get_string(query_embcache, qid, tokenizer)
                        debug_g.write(q1 + "\n")
                        p1 = get_string(passage_embcache, pos_pid, tokenizer)
                        debug_g.write(p1 + "\n-------------\n")
                        for nid in neg_pids:
                            ns = get_string(passage_embcache, nid, tokenizer)
                            debug_g.write(ns + "\n")
                        debug_g.write("-------------\n")
                        for nid in neg_pids_flat:
                            ns = get_string(passage_embcache, nid, tokenizer)
                            debug_g.write(ns + "\n")
                        debug_g.write("===============\n")

    logger.info("*****Done Constructing ANN Triplet *****")
    ndcg_output_path = os.path.join(args.output_dir, "ann_ndcg_" + str(output_num))
    with open(ndcg_output_path, 'w') as f:
        json.dump({'ndcg': dev_ndcg, 'checkpoint': checkpoint_path}, f)

    return dev_ndcg, num_queries_dev