示例#1
0
    def test_shards(self):
        """Check that sharded indexes return (almost) the same neighbours as a flat index.

        A flat L2 index over all of ``xb`` provides the reference result.
        Two ``IndexShards`` objects are then built from three sub-indexes:

        * ``shard_index`` receives shards that were each pre-filled with a
          contiguous third of ``xb``;
        * ``shard_index_2`` receives empty ``IndexIDMap``-wrapped shards and
          is filled through a single ``add`` call on the shard container.

        Three search passes are compared with the reference: non-threaded
        ``shard_index``, threaded ``shard_index`` (with OpenMP limited to one
        thread) and ``shard_index_2``.  Fewer than 0.1% of the returned ids
        may differ from the reference.

        ``d``, ``nb``, ``nq``, ``xb`` and ``xq`` are taken from the enclosing
        scope; they are not defined in this method.
        """
        k = 32
        ref_index = faiss.IndexFlatL2(d)

        print('ref search')
        ref_index.add(xb)
        _Dref, Iref = ref_index.search(xq, k)
        print(Iref[:5, :6])

        shard_index = faiss.IndexShards(d)
        # NOTE(review): the positional flags are presumably
        # (threaded=True, successive_ids=False) -- confirm against the
        # faiss IndexShards constructor.
        shard_index_2 = faiss.IndexShards(d, True, False)

        ni = 3
        for i in range(ni):
            # Contiguous slice [i0, i1) of the database for shard i.
            i0 = int(i * nb / ni)
            i1 = int((i + 1) * nb / ni)
            index = faiss.IndexFlatL2(d)
            index.add(xb[i0:i1])
            shard_index.add_shard(index)

            # Empty shard; it is populated later through shard_index_2.add().
            index_2 = faiss.IndexFlatL2(d)
            irm = faiss.IndexIDMap(index_2)
            shard_index_2.add_shard(irm)

        # test parallel add
        shard_index_2.verbose = True
        shard_index_2.add(xb)

        for test_no in range(3):
            with_threads = test_no == 1

            print('shard search test_no = %d' % test_no)
            if with_threads:
                remember_nt = faiss.omp_get_max_threads()
                faiss.omp_set_num_threads(1)
                shard_index.threaded = True
            else:
                shard_index.threaded = False

            try:
                if test_no != 2:
                    _D, I = shard_index.search(xq, k)
                else:
                    _D, I = shard_index_2.search(xq, k)
            finally:
                # Restore the global OpenMP setting even if the search
                # raises, so a failure here cannot slow down or perturb
                # the tests that run afterwards.
                if with_threads:
                    faiss.omp_set_num_threads(remember_nt)

            print(I[:5, :6])

            ndiff = (I != Iref).sum()

            print('%d / %d differences' % (ndiff, nq * k))
            assert (ndiff < nq * k / 1000.)
示例#2
0
    def test_replicas(self):
        """Check that ``IndexBinaryReplicas`` matches a plain ``IndexBinaryFlat``.

        Two replica containers are exercised, each holding five sub-indexes:

        * ``index``: every replica is filled individually before being
          attached;
        * ``index2``: empty replicas are attached and the data is added
          through the container.

        Both must return exactly the same distances and ids as the
        reference flat index for a 10-nearest-neighbour search.
        """
        d = 32
        nq = 100
        nb = 200

        (_, xb, xq) = make_binary_dataset(d, 0, nb, nq)

        index_ref = faiss.IndexBinaryFlat(d)
        index_ref.add(xb)

        Dref, Iref = index_ref.search(xq, 10)

        # there is a OpenMP bug in this configuration, so disable threading
        if sys.platform == "darwin" and "Clang 12" in sys.version:
            nthreads = faiss.omp_get_max_threads()
            faiss.omp_set_num_threads(1)
        else:
            nthreads = None

        try:
            nrep = 5
            index = faiss.IndexBinaryReplicas()
            for _i in range(nrep):
                sub_idx = faiss.IndexBinaryFlat(d)
                sub_idx.add(xb)
                index.addIndex(sub_idx)

            D, I = index.search(xq, 10)

            self.assertTrue((Dref == D).all())
            self.assertTrue((Iref == I).all())

            index2 = faiss.IndexBinaryReplicas()
            for _i in range(nrep):
                sub_idx = faiss.IndexBinaryFlat(d)
                index2.addIndex(sub_idx)

            index2.add(xb)
            D2, I2 = index2.search(xq, 10)
        finally:
            # Always undo the single-thread workaround, including when an
            # assertion or a faiss call above fails; otherwise the reduced
            # thread count would leak into every later test.
            if nthreads is not None:
                faiss.omp_set_num_threads(nthreads)

        self.assertTrue((Dref == D2).all())
        self.assertTrue((Iref == I2).all())
示例#3
0
    def test_hnsw(self):
        """Check that ``IndexBinaryHNSW`` reproduces the float-HNSW reference distances.

        The reference is an ``IndexHNSWFlat`` wrapped in
        ``IndexBinaryFromFloat``.  Both indexes are built from ``self.xb``
        with OpenMP limited to one thread, then queried with ``self.xq``
        for 3 neighbours.  Only the returned distances are compared.
        """
        # NOTE(review): presumably xq holds packed binary codes, one byte per
        # 8 dimensions -- confirm against the fixture that creates self.xq.
        d = self.xq.shape[1] * 8

        # NOTE(hoss): Ensure the HNSW construction is deterministic.
        nthreads = faiss.omp_get_max_threads()
        faiss.omp_set_num_threads(1)
        try:
            index_hnsw_float = faiss.IndexHNSWFlat(d, 16)
            index_hnsw_ref = faiss.IndexBinaryFromFloat(index_hnsw_float)

            index_hnsw_bin = faiss.IndexBinaryHNSW(d, 16)

            index_hnsw_ref.add(self.xb)
            index_hnsw_bin.add(self.xb)
        finally:
            # Restore the thread count even when construction fails so the
            # single-thread setting does not leak into other tests.
            faiss.omp_set_num_threads(nthreads)

        # The returned ids are not checked, so they are discarded.
        Dref, _ = index_hnsw_ref.search(self.xq, 3)
        Dbin, _ = index_hnsw_bin.search(self.xq, 3)

        self.assertTrue((Dref == Dbin).all())
示例#4
0
    def parse(self):
        """Parse the command line, derive runtime settings and start the run.

        Steps performed, in order:

        1. parse and validate the arguments via ``self.parser`` and
           ``self.check_arguments``;
        2. store a deep copy of the parsed arguments in
           ``args.input_arguments`` before any derived field is attached;
        3. initialise the distributed backend, which yields ``args.nranks``
           and ``args.distributed``;
        4. compute ``args.nthreads`` as 80% of the larger of the CPU count
           and the FAISS OpenMP maximum, divided across ranks (at least 1);
        5. when more than one rank is used, apply that limit to FAISS;
        6. initialise ``Run`` and log the arguments.

        Returns:
            The argument namespace, extended with ``input_arguments``,
            ``nranks``, ``distributed`` and ``nthreads``.
        """
        args = self.parser.parse_args()
        self.check_arguments(args)

        args.input_arguments = copy.deepcopy(args)

        args.nranks, args.distributed = distributed.init(args.rank)

        # os.cpu_count() returns None when the count cannot be determined;
        # comparing None with an int would raise TypeError, so fall back to 1.
        cpu_count = os.cpu_count() or 1
        args.nthreads = int(
            max(cpu_count, faiss.omp_get_max_threads()) * 0.8)
        args.nthreads = max(1, args.nthreads // args.nranks)

        if args.nranks > 1:
            print_message(
                f"#> Restricting number of threads for FAISS to {args.nthreads} per process",
                condition=(args.rank == 0))
            faiss.omp_set_num_threads(args.nthreads)

        Run.init(args.rank, args.root, args.experiment, args.run)
        Run._log_args(args)
        Run.info(args.input_arguments.__dict__, '\n')

        return args
示例#5
0
    # NOTE(review): this fragment uses Python 2 `print` statements and will
    # not parse under Python 3. The indentation suggests it is the tail of a
    # loop over index keys whose header is outside this excerpt -- confirm.

    # find operating points for this index
    opi = params.explore(index, xq, crit)

    print "[%.3f s] result operating points:" % (time.time() - t0)
    opi.display()

    # update best operating points so far
    # (the index key plus a space is passed as the second argument)
    op.merge_with(opi, index_key + " ")

    # remember this index's operating points so the plot below can draw
    # one curve per index explored so far
    op_per_key.append((index_key, opi))

    if True:
        # graphical output (to tmp/ subdirectory)
        # `if True` acts as a hard-coded on/off switch for the plotting code.

        fig = pyplot.figure(figsize=(12, 9))
        pyplot.xlabel("1-recall at 1")
        # the y-axis label records the OpenMP thread count in effect
        pyplot.ylabel("search time (ms/query, %d threads)" %
                      faiss.omp_get_max_threads())
        pyplot.gca().set_yscale('log')
        pyplot.grid()
        # redraw every index collected so far, labelled by its key
        for i2, opi2 in op_per_key:
            plot_OperatingPoints(opi2, crit.nq, label=i2, marker='o')
        # plot_OperatingPoints(op, crit.nq, label = 'best', marker = 'o', color = 'r')
        pyplot.legend(loc=2)
        # the same file name is used every time, so each save overwrites
        # the previous image
        fig.savefig('tmp/demo_auto_tune.png')

# t0 is the start timestamp set outside this excerpt
print "[%.3f s] final result:" % (time.time() - t0)

# display the merged operating points accumulated in `op`
op.display()
示例#6
0
    # NOTE(review): Python 2 `print` statements are used here; this fragment
    # does not parse under Python 3. It appears to be the body of a loop
    # whose header lies outside this excerpt -- confirm.

    # find operating points for this index
    opi = params.explore(index, xq, crit)

    print "[%.3f s] result operating points:" % (time.time() - t0)
    opi.display()

    # update best operating points so far
    op.merge_with(opi, index_key + " ")

    # keep (key, operating points) pairs for the per-index curves plotted below
    op_per_key.append((index_key, opi))

    if True:
        # graphical output (to tmp/ subdirectory)
        # the constant condition is a manual toggle for the plotting step

        fig = pyplot.figure(figsize=(12, 9))
        pyplot.xlabel("1-recall at 1")
        pyplot.ylabel("search time (ms/query, %d threads)" % faiss.omp_get_max_threads())
        # logarithmic y-axis for the search-time values
        pyplot.gca().set_yscale('log')
        pyplot.grid()
        # one curve per index key gathered so far
        for i2, opi2 in op_per_key:
            plot_OperatingPoints(opi2, crit.nq, label = i2, marker = 'o')
        # plot_OperatingPoints(op, crit.nq, label = 'best', marker = 'o', color = 'r')
        pyplot.legend(loc=2)
        # fixed output path: the image is overwritten on every save
        fig.savefig('tmp/demo_auto_tune.png')


# elapsed time is measured from t0, which is defined outside this excerpt
print "[%.3f s] final result:" % (time.time() - t0)

# show the merged operating points held in `op`
op.display()
示例#7
0
print("precomputed tables size:", precomputed_table_size)


#############################################################
# Index is ready
#############################################################

# Load the query vectors and the ground truth restricted to args.k results.
xq = sanitize(ds.get_queries())
gt = ds.get_groundtruth(k=args.k)
# The assert "message" is only evaluated when the check fails, so a
# ground-truth width mismatch drops into the pdb debugger before the
# AssertionError is raised.
assert gt.shape[1] == args.k, pdb.set_trace()

# A value of -1 leaves the OpenMP thread count untouched and only reports
# it; any other value is applied as the thread count.
if args.searchthreads != -1:
    print("Setting nb of threads to", args.searchthreads)
    faiss.omp_set_num_threads(args.searchthreads)
else:
    print("nb search threads: ", faiss.omp_get_max_threads())

# Parameter space initialised from the index built earlier in the script.
ps = faiss.ParameterSpace()
ps.initialize(index)

# Search parameter settings supplied on the command line.
parametersets = args.searchparams



if args.inter:
    header = (
        '%-40s     inter@%3d time(ms/q)   nb distances #runs' %
        ("parameters", args.k)
    )
else:
示例#8
0
    # NOTE(review): written with Python 2 `print` statements, so this
    # fragment is not valid Python 3. Its indentation suggests an enclosing
    # loop that starts outside this excerpt -- confirm.

    # find operating points for this index
    opi = params.explore(index, xq, crit)

    print "[%.3f s] result operating points:" % (time.time() - t0)
    opi.display()

    # update best operating points so far
    op.merge_with(opi, index_key + " ")

    # record this index's result; the plot below iterates over this list
    op_per_key.append((index_key, opi))

    if True:
        # graphical output (to tmp/ subdirectory)
        # always executed: the literal True is a hand-edited switch

        fig = pyplot.figure(figsize=(12, 9))
        pyplot.xlabel("1-recall at 1")
        # label includes the current OpenMP thread count
        pyplot.ylabel("search time (ms/query, %d threads)" % faiss.omp_get_max_threads())
        pyplot.gca().set_yscale('log')
        pyplot.grid()
        # plot all indexes explored so far, using the index key as the label
        for i2, opi2 in op_per_key:
            plot_OperatingPoints(opi2, crit.nq, label = i2, marker = 'o')
        # plot_OperatingPoints(op, crit.nq, label = 'best', marker = 'o', color = 'r')
        pyplot.legend(loc=2)
        # saved under a constant name, replacing any earlier image
        fig.savefig('tmp/demo_auto_tune.png')


# report total elapsed time since t0 (set outside this excerpt)
print "[%.3f s] final result:" % (time.time() - t0)

# print the merged operating points stored in `op`
op.display()
示例#9
0
    def test_shards(self):
        """Check that sharded indexes return (almost) the same neighbours as a flat index.

        A flat L2 index over all of ``xb`` provides the reference ids.  Two
        ``IndexShards`` containers with three shards each are compared
        against it:

        * ``shard_index``: each shard is pre-filled with a contiguous third
          of ``xb``;
        * ``shard_index_2``: empty ``IndexIDMap``-wrapped shards that are
          filled by one ``add`` call on the container.

        Three passes are run: non-threaded ``shard_index``, threaded
        ``shard_index`` with OpenMP limited to one thread, and
        ``shard_index_2``.  Fewer than 0.1% of the ids may differ from the
        reference.  On macOS builds whose ``sys.version`` mentions
        "Clang 12", OpenMP is limited to one thread for the whole sharded
        part of the test.

        ``d``, ``nb``, ``nq``, ``xb`` and ``xq`` come from the enclosing
        scope; they are not defined in this method.
        """
        k = 32
        ref_index = faiss.IndexFlatL2(d)

        print('ref search')
        ref_index.add(xb)
        _Dref, Iref = ref_index.search(xq, k)
        print(Iref[:5, :6])

        # there is a OpenMP bug in this configuration, so disable threading
        if sys.platform == "darwin" and "Clang 12" in sys.version:
            nthreads = faiss.omp_get_max_threads()
            faiss.omp_set_num_threads(1)
        else:
            nthreads = None

        try:
            shard_index = faiss.IndexShards(d)
            # NOTE(review): the positional flags are presumably
            # (threaded=True, successive_ids=False) -- confirm against the
            # faiss IndexShards constructor.
            shard_index_2 = faiss.IndexShards(d, True, False)

            ni = 3
            for i in range(ni):
                # Contiguous slice [i0, i1) of the database for shard i.
                i0 = int(i * nb / ni)
                i1 = int((i + 1) * nb / ni)
                index = faiss.IndexFlatL2(d)
                index.add(xb[i0:i1])
                shard_index.add_shard(index)

                # Empty shard; populated later via shard_index_2.add().
                index_2 = faiss.IndexFlatL2(d)
                irm = faiss.IndexIDMap(index_2)
                shard_index_2.add_shard(irm)

            # test parallel add
            shard_index_2.verbose = True
            shard_index_2.add(xb)

            for test_no in range(3):
                with_threads = test_no == 1

                print('shard search test_no = %d' % test_no)
                if with_threads:
                    remember_nt = faiss.omp_get_max_threads()
                    faiss.omp_set_num_threads(1)
                    shard_index.threaded = True
                else:
                    shard_index.threaded = False

                try:
                    if test_no != 2:
                        _D, I = shard_index.search(xq, k)
                    else:
                        _D, I = shard_index_2.search(xq, k)
                finally:
                    # Undo the per-pass override even if the search raises.
                    if with_threads:
                        faiss.omp_set_num_threads(remember_nt)

                print(I[:5, :6])

                ndiff = (I != Iref).sum()

                print('%d / %d differences' % (ndiff, nq * k))
                assert(ndiff < nq * k / 1000.)
        finally:
            # Undo the platform workaround on every exit path so a failing
            # assertion cannot leave the process single-threaded for the
            # tests that follow.
            if nthreads is not None:
                faiss.omp_set_num_threads(nthreads)
def generate_new_ann(args, output_num, checkpoint_path, training_query_positive_id, dev_query_positive_id, query_embcache, passage_embcache):
    """Build ANN-mined training triplets from a model checkpoint.

    The function loads the checkpoint (retrying every 60 seconds until it
    succeeds), embeds a fraction of the passages, builds a FAISS index over
    them on every process, then streams query batches through the model and
    searches the index.  The first worker writes one line per query with a
    known positive passage to ``ann_training_data_<output_num>`` in the form
    ``qid<TAB>positive_pid<TAB>comma-separated negative pids``.  For the
    first batch it also writes human-readable debug text for up to 100
    queries to ``ann_debug_<output_num>``.  Finally
    ``ann_ndcg_<output_num>`` is written as JSON.

    Args:
        args: run configuration.  Fields read here: ``corpus_divider``
            (clamped in place to ``[0, 1]``), ``per_gpu_eval_batch_size``,
            ``topk_training``, ``flat_index``, ``output_dir``,
            ``ann_chunk_factor``, ``local_rank``, ``device``,
            ``ann_measure_topk_mrr`` and ``negative_sample``.
        output_num: suffix for the output files; also used to reseed both
            embedding caches.
        checkpoint_path: checkpoint passed to ``load_model``.
        training_query_positive_id: mapping from query id to its positive
            passage id; only queries present in it produce output.
        dev_query_positive_id: not used by this function.
        query_embcache: source of query inputs.
        passage_embcache: source of passage inputs.

    Returns:
        ``(dev_ndcg, num_queries_dev)``.  This function never updates
        them, so the result is always ``(0.0, 0)``.
    """
    # Keep retrying until the checkpoint can be loaded.  Only ordinary
    # errors are retried, so Ctrl-C / SystemExit still stop the process,
    # and the failure reason is logged instead of being discarded.
    while True:
        try:
            _config, tokenizer, model = load_model(args, checkpoint_path)
            break
        except Exception:
            logger.exception("loading model from %s failed", checkpoint_path)
            time.sleep(60)
            print("retry loading model")

    passage_embcache.change_seed(output_num)
    query_embcache.change_seed(output_num)
    model.eval()
    logger.info("***** inference of passages *****")
    num_batches_per_gpu = len(passage_embcache)//(args.per_gpu_eval_batch_size*dist.get_world_size())

    # Clamp the corpus fraction to [0, 1].  The `math` module has no
    # max/min functions; the built-ins are the correct tools.
    args.corpus_divider = max(min(args.corpus_divider, 1.0), 0.0)
    # only run embedding inference for a `corpus_divider` fraction of the
    # passage batches to speed up the process
    passage_embedding, passage_embedding2id = StreamInferenceDoc(args, model, GetProcessingFn(args, query=False), "passage_", passage_embcache, is_query_inference = False, end_batch=num_batches_per_gpu*args.corpus_divider)
    logger.info("***** Done passage inference *****")

    # build index partition on each process
    dim = passage_embedding.shape[1]
    logger.info('passage embedding shape: ' + str(passage_embedding.shape))
    top_k = args.topk_training
    # Split a 32-thread budget across processes, never requesting 0
    # threads when more than 32 processes are running.
    faiss.omp_set_num_threads(max(1, 32//dist.get_world_size()))
    print(faiss.omp_get_max_threads())

    if args.flat_index:
        cpu_index = faiss.IndexFlatIP(dim)
    else:
        cpu_index = faiss.index_factory(dim, "IVF8192,Flat")
        cpu_index.train(passage_embedding)

    cpu_index.add(passage_embedding)
    logger.info("***** Done training Index *****")
    cpu_index.nprobe = 50
    logger.info("***** Done building ANN Index *****")
    dist.barrier()

    # An exact inner-product index, used only for the debug dump of the
    # first batch.  Reuse cpu_index when it is already flat.
    if args.flat_index:
        flat_index = cpu_index
    else:
        flat_index = faiss.IndexFlatIP(dim)
        flat_index.add(passage_embedding)

    logger.info("**** Done building flat index *****")

    dev_ndcg, num_queries_dev = 0.0, 0
    train_data_output_path = os.path.join(args.output_dir, "ann_training_data_" + str(output_num))
    debug_output_path = os.path.join(args.output_dir, "ann_debug_"+ str(output_num))

    with open(train_data_output_path, 'w') as f, open(debug_output_path, "w", encoding="utf-8") as debug_g:
        chunk_factor = args.ann_chunk_factor
        if chunk_factor <= 0:
            chunk_factor = 1
        num_batches_per_gpu = len(query_embcache)//(args.per_gpu_eval_batch_size*dist.get_world_size())
        batches_per_chunk = num_batches_per_gpu // chunk_factor
        end_idx = batches_per_chunk
        print("End idx:", end_idx)

        inference_dataset = StreamingDataset(query_embcache, GetProcessingFn(args, query=True))
        inference_dataloader = DataLoader(inference_dataset, batch_size=args.per_gpu_eval_batch_size)
        for m_batch, batch in tqdm(enumerate(inference_dataloader), desc="Inferencing", disable=args.local_rank not in [-1, 0], position=0):
            if m_batch>end_idx:
                break
            qids = batch[3].detach().numpy() #[#B]
            batch = tuple(t.to(args.device) for t in batch)
            with torch.no_grad():
                inputs = {"input_ids": batch[0].long(), "attention_mask": batch[1].long()}
                embs = model.module.query_emb(**inputs)
            embs = embs.detach().cpu().numpy()
            # take only queries with positive passage, then collect
            pos_idx = [i for i,qid in enumerate(qids) if qid in training_query_positive_id]
            query_embedding = embs[pos_idx]
            q_chunk_qid = qids[pos_idx]

            # Every rank gathers the queries of all ranks, so each local
            # index partition is searched with the same query set.
            tmp_obj = {"emb": query_embedding, "id": q_chunk_qid}
            objs = all_gather(tmp_obj)
            query_embedding = concat_key(objs, "emb", axis=0)
            q_chunk_qid = concat_key(objs, "id", axis=0)
            if m_batch==0:
                print(query_embedding.shape, q_chunk_qid.shape)
            D, I = cpu_index.search(query_embedding, top_k)
            # Map index positions back to passage ids.
            I = passage_embedding2id[I]

            if m_batch%100==0:
                logger.info(f"***** Done querying ANN Index chunk {m_batch}*****")
            knn_pkl = {"D": D, "I": I}
            all_knn_list = all_gather(knn_pkl)
            del knn_pkl

            if m_batch==0:
                # we only do flat_index search to debug
                Df, If = flat_index.search(query_embedding, top_k)
                If = passage_embedding2id[If]

                knn_pkl_flat = {"D": Df, "I": If}
                all_knn_list_flat = all_gather(knn_pkl_flat)
                del knn_pkl_flat

            if is_first_worker():
                # Per-query candidates from all ranks, side by side.
                D_merged = concat_key(all_knn_list, "D", axis=1)
                I_merged = concat_key(all_knn_list, "I", axis=1)

                if not args.ann_measure_topk_mrr:
                    # Random subset of candidate columns (same columns for
                    # every query).
                    shuffled_ix = np.random.permutation(I_merged.shape[1])[:args.negative_sample + 1]
                    sub_I = np.take(I_merged, shuffled_ix, axis=1)
                else:
                    # Highest-scoring candidates per query.  The scores
                    # are inner products (larger is better), so sort in
                    # descending order and keep the leading *columns*.
                    # Slicing rows here would break take_along_axis and
                    # the shape checks below.
                    top_idx = np.argsort(D_merged, axis=1)[:, ::-1][:, :args.negative_sample + 1]
                    sub_I = np.take_along_axis(I_merged, top_idx, axis=1)
                assert sub_I.shape[0] == len(q_chunk_qid)
                assert sub_I.shape[1] == args.negative_sample+1

                if m_batch==0:
                    D_merged_flat = concat_key(all_knn_list_flat, "D", axis=1)
                    I_merged_flat = concat_key(all_knn_list_flat, "I", axis=1)
                    shuffled_ix_flat = np.random.permutation(I_merged_flat.shape[1])[:args.negative_sample + 1]
                    sub_I_flat = np.take(I_merged_flat, shuffled_ix_flat, axis=1)
                else:
                    # Placeholder so the zip below has something to pair
                    # with; flat results are only used for batch 0.
                    sub_I_flat = [0]*len(sub_I)

                for i, (qid, row, row_flat) in enumerate(zip(q_chunk_qid, sub_I, sub_I_flat)):
                    pos_pid = training_query_positive_id[qid]
                    # Drop the positive passage if it was retrieved, then
                    # cap the list at negative_sample entries.
                    neg_pids = [x for x in row if x!=pos_pid][:args.negative_sample]
                    if m_batch==0:
                        neg_pids_flat = [x for x in row_flat if x!=pos_pid][:args.negative_sample]
                    f.write("{}\t{}\t{}\n".format(qid, pos_pid, ','.join(str(nid) for nid in neg_pids)))
                    # console might crash if encoding isnt set correctly. run export PYTHONIOENCODING=UTF-8 before running the script if that happens.
                    if m_batch==0 and i<100:
                        # Debug dump: query, positive passage, ANN
                        # negatives, then flat-index negatives.
                        q1 = get_string(query_embcache, qid, tokenizer)
                        debug_g.write(q1+"\n")
                        p1 = get_string(passage_embcache, pos_pid, tokenizer)
                        debug_g.write(p1+"\n-------------\n")
                        for nid in neg_pids:
                            ns = get_string(passage_embcache, nid, tokenizer)
                            debug_g.write(ns+"\n")
                        debug_g.write("-------------\n")
                        for nid in neg_pids_flat:
                            ns = get_string(passage_embcache, nid, tokenizer)
                            debug_g.write(ns+"\n")
                        debug_g.write("===============\n")

    logger.info("*****Done Constructing ANN Triplet *****")
    ndcg_output_path = os.path.join(args.output_dir, "ann_ndcg_" + str(output_num))
    with open(ndcg_output_path, 'w') as f:
        json.dump({'ndcg': dev_ndcg, 'checkpoint': checkpoint_path}, f)

    return dev_ndcg, num_queries_dev