示例#1
0
def ance(qid, query, docs, index_path, k=1000,
         encoder='castorini/ance-msmarco-doc-maxp'):
    """Run a dense search for one query and return a de-duplicated ranking.

    Args:
        qid: Query identifier; copied verbatim into every output row.
        query: Query text handed to the searcher.
        docs: Unused by this function; kept so existing callers keep working.
        index_path: Location of the dense index passed to
            ``SimpleDenseSearcher``.
        k: Number of hits requested from the searcher (default 1000).
        encoder: Query-encoder name passed to ``SimpleDenseSearcher``.

    Returns:
        A list of ``[qid, docid, rank]`` rows in the order the searcher
        returned them. Ranks start at 1 and are consecutive; when the same
        docid appears more than once only its first occurrence is kept.
    """
    searcher = SimpleDenseSearcher(index_path, encoder)
    hits = searcher.search(query, k)

    output = []
    seen_docids = set()
    for hit in hits:
        docid = hit.docid
        if docid in seen_docids:
            # Keep only the first (earliest-returned) occurrence of a docid.
            continue
        seen_docids.add(docid)
        # The rank is the 1-based position among the rows kept so far.
        output.append([qid, docid, len(output) + 1])

    return output
示例#2
0
    def test_dpr_encode_as_faiss(self):
        """Encode the corpus with DPR directly into a Faiss index, then search it.

        Runs ``python -m pyserini.encode`` through the shell with
        ``--to-faiss``, opens the resulting directory with
        ``SimpleDenseSearcher`` and checks the top hit, the first component
        of its returned vector, and the number of indexed documents.
        """
        index_dir = f'{self.pyserini_root}/temp_index'
        # Recorded on the fixture -- presumably so teardown can delete the
        # directory; confirm against the enclosing test class.
        self.temp_folders.append(index_dir)
        cmd = ' '.join([
            'python -m pyserini.encode',
            f'input --corpus {self.corpus_path} --fields text',
            f'output --embeddings {index_dir} --to-faiss',
            'encoder --encoder facebook/dpr-ctx_encoder-multiset-base',
            '--fields text --batch 4 --device cpu',
        ])

        # Fail here, with a clear message, if encoding did not succeed;
        # otherwise the failure would only surface later as a confusing
        # searcher or assertion error.
        status = os.system(cmd)
        self.assertEqual(
            status, 0, f'pyserini.encode exited with status {status}')

        searcher = SimpleDenseSearcher(
            index_dir, 'facebook/dpr-question_encoder-multiset-base')
        # With return_vector=True the search result unpacks into two values;
        # only the second (the hit list) is inspected here.
        _q_emb, hits = searcher.search(
            "What is the solution of separable closed queueing networks?",
            k=1,
            return_vector=True)
        self.assertEqual(hits[0].docid, 'CACM-2445')
        self.assertAlmostEqual(hits[0].vectors[0], -6.88267112e-01, places=4)
        self.assertEqual(searcher.num_docs, 3204)
示例#3
0
    print(f'Running {args.topics} topics, saving to {output_path}...')
    tag = 'Faiss'

    order = None
    if args.topics in QUERY_IDS:
        print(f'Using pre-defined topic order for {args.topics}')
        order = QUERY_IDS[args.topics]

    with open(output_path, 'w') as target_file:
        batch_topics = list()
        batch_topic_ids = list()
        for index, (topic_id, text) in enumerate(
                tqdm(list(query_iterator(topics, order)))):
            if args.batch_size <= 1 and args.threads <= 1:
                hits = searcher.search(text, args.hits)
                results = [(topic_id, hits)]
            else:
                batch_topic_ids.append(str(topic_id))
                batch_topics.append(text)
                if (index + 1) % args.batch_size == 0 or \
                        index == len(topics.keys()) - 1:
                    results = searcher.batch_search(batch_topics,
                                                    batch_topic_ids, args.hits,
                                                    args.threads)
                    results = [(id_, results[id_]) for id_ in batch_topic_ids]
                    batch_topic_ids.clear()
                    batch_topics.clear()
                else:
                    continue
示例#4
0
                topics[topic].get('title').strip() for topic in topic_key_batch
            ]
            hits = searcher.batch_search(topic_batch,
                                         topic_key_batch,
                                         k=args.hits,
                                         threads=args.threads)
            for topic in hits:
                for idx, hit in enumerate(hits[topic]):
                    if args.msmarco:
                        target_file.write(f'{topic}\t{hit.docid}\t{idx + 1}\n')
                    else:
                        target_file.write(
                            f'{topic} Q0 {hit.docid} {idx + 1} {hit.score:.6f} {tag}\n'
                        )
    exit()

# Search the topics one at a time, in sorted key order, and write each
# topic's ranked hits to the output file.
with open(output_path, 'w') as target_file:
    for topic in tqdm(sorted(topics.keys())):
        query_text = topics[topic].get('title').strip()
        hits = searcher.search(query_text, args.hits, threads=args.threads)

        # Ranks are the 1-based positions in the returned hit list.
        for rank, hit in enumerate(hits, start=1):
            docid = hit.docid.strip()
            if args.msmarco:
                # Tab-separated: topic, docid, rank.
                line = f'{topic}\t{docid}\t{rank}\n'
            else:
                # Space-separated: topic, literal "Q0", docid, rank,
                # score with six decimals, run tag.
                line = f'{topic} Q0 {docid} {rank} {hit.score:.6f} {tag}\n'
            target_file.write(line)