コード例 #1
0
def check_dense(index):
    """Smoke-test that every prebuilt dense index named in ``index`` can be loaded.

    Nothing about the results is validated: a searcher is constructed for each
    entry, and the check passes as long as no exception is raised. Entries whose
    name contains "bpr" are opened with ``BinaryDenseSearcher``; all others with
    ``SimpleDenseSearcher``.

    Args:
        index: iterable of prebuilt index names to try loading.
    """
    # Pre-encoded dev queries stand in for a live query encoder so that no
    # model has to be loaded just to open the indexes.
    encoded_queries = QueryEncoder.load_encoded_queries(
        'tct_colbert-msmarco-passage-dev-subset')
    print('\n')
    for name in index:
        print(f'# Validating "{name}"...')
        searcher_cls = BinaryDenseSearcher if "bpr" in name else SimpleDenseSearcher
        searcher_cls.from_prebuilt_index(name, encoded_queries)
        print('\n')
コード例 #2
0
def ance(qid, query, docs, index_path):
    """Retrieve up to 1000 hits for ``query`` with ANCE and return deduplicated rankings.

    Args:
        qid: query identifier, copied into every output row.
        query: query text passed to the dense searcher.
        docs: unused; kept so existing callers keep working.
        index_path: path to the dense index directory.

    Returns:
        A list of ``[qid, docid, rank]`` rows in search order. Only the first
        occurrence of each docid is kept, and ranks are numbered 1..n over the
        kept rows, so they stay contiguous after duplicates are dropped.
    """
    # NOTE(review): the searcher (and its index) is rebuilt on every call;
    # callers looping over many queries may want to construct it once instead.
    searcher = SimpleDenseSearcher(index_path, 'castorini/ance-msmarco-doc-maxp')
    hits = searcher.search(query, 1000)

    output = []
    seen_docids = set()
    for hit in hits:
        if hit.docid in seen_docids:
            continue
        seen_docids.add(hit.docid)
        # Rank is the position among kept rows, not among raw hits.
        output.append([qid, hit.docid, len(output) + 1])
    return output
コード例 #3
0
ファイル: test_encode.py プロジェクト: castorini/pyserini
    def test_dpr_encode_as_faiss(self):
        """Encode the test corpus with DPR directly into a Faiss index, then search it.

        Runs ``pyserini.encode`` in a subprocess (CPU, batch size 4), opens the
        resulting index with ``SimpleDenseSearcher``, and checks the top hit's
        docid, the first component of its returned vector, and the document count.
        """
        index_dir = f'{self.pyserini_root}/temp_index'
        # Registered so the fixture can remove it afterwards (presumably in tearDown).
        self.temp_folders.append(index_dir)
        cmd1 = f"python -m pyserini.encode input   --corpus {self.corpus_path} \
                                  --fields text \
                          output  --embeddings {index_dir} --to-faiss \
                          encoder --encoder facebook/dpr-ctx_encoder-multiset-base \
                                  --fields text \
                                  --batch 4 \
                                  --device cpu"

        # Fail fast if encoding fails. Otherwise the searcher below errors out
        # with an unrelated-looking message about a missing or partial index.
        status = os.system(cmd1)
        self.assertEqual(status, 0, f'encoding command failed with status {status}: {cmd1}')

        searcher = SimpleDenseSearcher(
            index_dir, 'facebook/dpr-question_encoder-multiset-base')
        # With return_vector=True the first element is the query embedding (unused here).
        _, hit = searcher.search(
            "What is the solution of separable closed queueing networks?",
            k=1,
            return_vector=True)
        self.assertEqual(hit[0].docid, 'CACM-2445')
        self.assertAlmostEqual(hit[0].vectors[0], -6.88267112e-01, places=4)
        self.assertEqual(searcher.num_docs, 3204)
コード例 #4
0
ファイル: search.py プロジェクト: spacemanidol/CS510IR
def load_ranker(args):
    """Build the searcher selected by the ``sparse`` / ``dense`` flags in ``args``.

    Args:
        args: parsed command-line namespace. Reads the ``sparse`` and ``dense``
            flags; ``sparse_index_path``, ``k``, ``b``, ``expansion_terms``,
            ``expansion_documents`` and ``original_query_weight`` for the sparse
            searcher; and ``dense_index_path`` for the dense one.

    Returns:
        A ``HybridSearcher`` combining both searchers when both flags are set.
        Otherwise, the single sparse (BM25 + RM3) or dense (TCT-ColBERT) searcher.

    Raises:
        SystemExit: with status 1 when neither flag is set.
    """
    if args.sparse and args.dense:
        # Build sparse first, as before, then combine. The hybrid searcher used
        # to be created but never returned, so callers received None.
        sparse_searcher = _load_sparse_searcher(args)
        dense_searcher = _load_dense_searcher(args)
        return HybridSearcher(dense_searcher, sparse_searcher)
    if args.sparse:
        return _load_sparse_searcher(args)
    if args.dense:
        return _load_dense_searcher(args)
    print(
        "Choose a valid ranking function sparse(BM25), dense(vector) or a combination of the two"
    )
    # Non-zero status: this is a configuration error, not a successful run.
    raise SystemExit(1)


def _load_sparse_searcher(args):
    """Open the sparse index at ``args.sparse_index_path`` configured with BM25 + RM3."""
    searcher = SimpleSearcher(args.sparse_index_path)
    searcher.set_bm25(args.k, args.b)
    searcher.set_rm3(args.expansion_terms, args.expansion_documents,
                     args.original_query_weight)
    return searcher


def _load_dense_searcher(args):
    """Open the dense index at ``args.dense_index_path`` with a TCT-ColBERT query encoder."""
    encoder = TCTColBERTQueryEncoder('castorini/tct_colbert-msmarco')
    return SimpleDenseSearcher(args.dense_index_path, encoder)
コード例 #5
0
    def do_model(self, arg):
        """Switch the dense model, rebuilding the dense and hybrid searchers.

        ``arg`` must be ``"tct"`` or ``"ance"``. Any other value prints an error
        and leaves the current searchers untouched.
        """
        # model name -> (query-encoder factory, prebuilt dense index name).
        # Factories are lambdas so that only the selected encoder is built.
        models = {
            "tct": (lambda: TctColBertQueryEncoder("castorini/tct_colbert-msmarco"),
                    "msmarco-passage-tct_colbert-hnsw"),
            "ance": (lambda: AnceQueryEncoder("castorini/ance-msmarco-passage"),
                     "msmarco-passage-ance-bf"),
        }
        selected = models.get(arg)
        if selected is None:
            print(f'Model "{arg}" is invalid. Model should be one of [tct, ance].')
            return

        make_encoder, index_name = selected
        query_encoder = make_encoder()
        self.dsearcher = SimpleDenseSearcher.from_prebuilt_index(index_name, query_encoder)
        # The hybrid searcher wraps the dense one, so it must be rebuilt as well.
        self.hsearcher = HybridSearcher(self.dsearcher, self.ssearcher)
        print(f'setting model = {arg}')
コード例 #6
0
    # invalid topics name
    if topics == {}:
        print(f'Topic {args.run.topics} Not Found')
        exit()

    query_encoder = init_query_encoder(args.dense.encoder,
                                       args.run.topics,
                                       args.dense.encoded_queries,
                                       args.dense.device)
    if not query_encoder:
        print(f'No encoded queries for topic {args.run.topics}')
        exit()

    if os.path.exists(args.dense.index):
        # create searcher from index directory
        dsearcher = SimpleDenseSearcher(args.dense.index, query_encoder)
    else:
        # create searcher from prebuilt index name
        dsearcher = SimpleDenseSearcher.from_prebuilt_index(args.dense.index, query_encoder)

    if not dsearcher:
        exit()

    if os.path.exists(args.sparse.index):
        # create searcher from index directory
        ssearcher = SimpleSearcher(args.sparse.index)
    else:
        # create searcher from prebuilt index name
        ssearcher = SimpleSearcher.from_prebuilt_index(args.sparse.index)

    if not ssearcher:
コード例 #7
0
        topics = get_topics(args.topics)

    # invalid topics name
    if topics == {}:
        print(f'Topic {args.topics} Not Found')
        exit()

    query_encoder = init_query_encoder(args.encoder, args.topics,
                                       args.encoded_queries, args.device)
    if not query_encoder:
        print(f'No encoded queries for topic {args.topics}')
        exit()

    if os.path.exists(args.index):
        # create searcher from index directory
        searcher = SimpleDenseSearcher(args.index, query_encoder)
    else:
        # create searcher from prebuilt index name
        searcher = SimpleDenseSearcher.from_prebuilt_index(
            args.index, query_encoder)

    if not searcher:
        exit()

    # build output path
    output_path = args.output

    print(f'Running {args.topics} topics, saving to {output_path}...')
    tag = 'Faiss'

    order = None
コード例 #8
0
    parser.add_argument('--query', type=str, required=False, default='', help="user query appended to predictions")
    # index corpus, device
    parser.add_argument('--reader-model', type=str, required=False, help="Reader model name or path")
    parser.add_argument('--reader-device', type=str, required=False, default='cuda:0', help="Device to run inference on")

    args = parser.parse_args()

    # check arguments
    arg_check(args, parser)

    print("Init QA models")
    if args.type == 'openbook':
        if args.qa_reader == 'dpr':
            reader = DprReader(args.reader_model, device=args.reader_device)
            if args.retriever_model:
                retriever = SimpleDenseSearcher(args.retriever_index, DprQueryEncoder(args.retriever_model))
            else:
                retriever = SimpleSearcher.from_prebuilt_index(args.retriever_corpus)
            corpus = SimpleSearcher.from_prebuilt_index(args.retriever_corpus)
            obqa = OpenBookQA(reader, retriever, corpus)
            # run a warm up question
            obqa.predict('what is lobster roll')
            while True:
                question = input('Enter a question: ')
                answer = obqa.predict(question)
                answer_text = answer["answer"]
                answer_context = answer["context"]["text"]
                print(f"Answer:\t {answer_text}")
                print(f"Context:\t {answer_context}")
        elif args.qa_reader == 'fid':
            reader = FidReader(model_name=args.reader_model, device=args.reader_device)
コード例 #9
0
        # create query encoder from query embedding directory
        query_encoder = TCTColBERTQueryEncoder(args.encoded_queries)
    else:
        # create query encoder from pre encoded query name
        query_encoder = TCTColBERTQueryEncoder.load_encoded_queries(
            args.encoded_queries)
else:
    query_encoder = TCTColBERTQueryEncoder(encoder_dir=args.encoder,
                                           device=args.device)

if not query_encoder:
    exit()

if os.path.exists(args.index):
    # create searcher from index directory
    searcher = SimpleDenseSearcher(args.index, query_encoder)
else:
    # create searcher from prebuilt index name
    searcher = SimpleDenseSearcher.from_prebuilt_index(args.index,
                                                       query_encoder)

if not searcher:
    exit()

# invalid topics name
if topics == {}:
    print(f'Topic {args.topics} Not Found')
    exit()

# build output path
output_path = args.output
コード例 #10
0
class DPRDemo(cmd.Cmd):
    """Interactive shell for searching the prebuilt 'wikipedia-dpr' indexes.

    Supports sparse, dense, or hybrid retrieval. Commands may optionally be
    prefixed with '/'. Any other input line is treated as a query.

    The topics and searchers below are created when the class body executes
    (i.e. at import time) and are shared class attributes. ``do_mode`` and
    ``do_k`` rebind ``searcher`` / ``k`` on the instance.
    """

    # Dev-set questions used by /random; each entry's question is read via ['title'].
    nq_dev_topics = list(search.get_topics('dpr-nq-dev').values())
    trivia_dev_topics = list(search.get_topics('dpr-trivia-dev').values())

    # The sparse searcher is also the initially active retriever.
    ssearcher = SimpleSearcher.from_prebuilt_index('wikipedia-dpr')
    searcher = ssearcher

    encoder = DprQueryEncoder("facebook/dpr-question_encoder-multiset-base")
    index = 'wikipedia-dpr-multi-bf'
    dsearcher = SimpleDenseSearcher.from_prebuilt_index(
        index,
        encoder
    )
    hsearcher = HybridSearcher(dsearcher, ssearcher)

    k = 10  # default number of hits printed per query
    prompt = '>>> '

    def precmd(self, line):
        """Strip an optional leading '/', so '/k 5' and 'k 5' are equivalent."""
        # startswith() is safe on the empty string produced by an empty input
        # line; the previous line[0] raised IndexError and killed the shell.
        if line.startswith('/'):
            line = line[1:]
        return line

    def do_help(self, arg):
        """Print the list of available commands."""
        print('/help    : returns this message')
        print('/k [NUM] : sets k (number of hits to return) to [NUM]')
        print('/mode [MODE] : sets retriver type to [MODE] (one of sparse, dense, hybrid)')
        print('/random [COLLECTION]: returns results for a random question from the dev subset [COLLECTION] (one of nq, trivia).')

    def do_k(self, arg):
        """Set the number of hits returned per query; non-integers are rejected."""
        try:
            k = int(arg)
        except ValueError:
            # Report and keep the shell alive instead of crashing the loop.
            print(f'k must be an integer, got "{arg}".')
            return
        print(f'setting k = {k}')
        self.k = k

    def do_mode(self, arg):
        """Select the active retriever: sparse, dense or hybrid."""
        if arg == "sparse":
            self.searcher = self.ssearcher
        elif arg == "dense":
            self.searcher = self.dsearcher
        elif arg == "hybrid":
            self.searcher = self.hsearcher
        else:
            print(
                f'Mode "{arg}" is invalid. Mode should be one of [sparse, dense, hybrid].')
            return
        print(f'setting retriver = {arg}')

    def do_random(self, arg):
        """Run a random dev question from the 'nq' or 'trivia' collection."""
        if arg == "nq":
            topics = self.nq_dev_topics
        elif arg == "trivia":
            topics = self.trivia_dev_topics
        else:
            print(
                f'Collection "{arg}" is invalid. Collection should be one of [nq, trivia].')
            return
        q = random.choice(topics)['title']
        print(f'question: {q}')
        self.default(q)

    def do_EOF(self, line):
        """Stop the command loop at end of input."""
        return True

    def default(self, q):
        """Run ``q`` against the active searcher and print the top ``self.k`` hits."""
        hits = self.searcher.search(q, self.k)

        for rank, hit in enumerate(hits, start=1):
            raw_doc = None
            if isinstance(self.searcher, SimpleSearcher):
                # Sparse hits carry the stored raw document directly.
                raw_doc = hit.raw
            else:
                # Dense/hybrid hits are looked up by docid through the searcher.
                doc = self.searcher.doc(hit.docid)
                if doc:
                    raw_doc = doc.raw()
            if raw_doc is None:
                # json.loads(None) would raise TypeError and end the session.
                print(f'{rank:2} {hit.score:.5f} [no stored document for docid {hit.docid}]')
                continue
            jsondoc = json.loads(raw_doc)
            print(f'{rank:2} {hit.score:.5f} {jsondoc["contents"]}')
コード例 #11
0
    parser = argparse.ArgumentParser(description='Interactive QA')

    commands = parser.add_subparsers(title='sub-commands')

    dense_parser = commands.add_parser('reader')
    define_reader_args(dense_parser)

    sparse_parser = commands.add_parser('retriever')
    define_retriever_args(sparse_parser)

    args = parse_args(parser, commands)

    print("Init QA models")
    reader = DprReader(args.reader.model, device=args.reader.device)
    if args.retriever.model:
        retriever = SimpleDenseSearcher(args.retriever.index,
                                        DprQueryEncoder(args.retriever.model))
    else:
        retriever = SimpleSearcher.from_prebuilt_index(args.retriever.corpus)
    corpus = SimpleSearcher.from_prebuilt_index(args.retriever.corpus)
    obqa = OpenBookQA(reader, retriever, corpus)

    # run a warm up question
    obqa.predict('what is lobster roll')
    while True:
        question = input('Please enter a question: ')
        answer = obqa.predict(question)
        answer_text = answer["answer"]
        answer_context = answer["context"]["text"]
        print(f"ANSWER:\t {answer_text}")
        print(f"CONTEXT:\t {answer_context}")