def check_dense(index):
    """Smoke-test every prebuilt dense index named in ``index``.

    There is no explicit validation: each searcher is simply constructed
    from its prebuilt index name, and the check passes as long as no
    exception is raised.

    Args:
        index: Iterable of prebuilt dense index names. Names containing
            ``"bpr"`` are opened with ``BinaryDenseSearcher``; every other
            name is opened with ``SimpleDenseSearcher``.
    """
    # Dummy queries: nothing is searched, the searcher constructors just
    # need a query encoder argument.
    placeholder_queries = QueryEncoder.load_encoded_queries(
        'tct_colbert-msmarco-passage-dev-subset')
    print('\n')
    for name in index:
        print(f'# Validating "{name}"...')
        searcher_cls = (
            BinaryDenseSearcher if "bpr" in name else SimpleDenseSearcher
        )
        searcher_cls.from_prebuilt_index(name, placeholder_queries)
        print('\n')
def ance(qid, query, docs, index_path):
    """Run an ANCE dense search and return de-duplicated ranked results.

    Args:
        qid: Query identifier copied into every output row.
        query: Query text passed to the searcher.
        docs: Unused; kept so the signature stays unchanged for callers.
        index_path: Path handed to ``SimpleDenseSearcher`` as the index.

    Returns:
        A list of ``[qid, docid, rank]`` rows in the order the hits were
        returned. ``rank`` starts at 1 and only advances for docids that
        have not been seen before, so repeated docids are dropped.
    """
    searcher = SimpleDenseSearcher(
        index_path, 'castorini/ance-msmarco-doc-maxp')

    ranked = []
    seen = set()
    for hit in searcher.search(query, 1000):
        docid = hit.docid
        if docid in seen:
            continue
        seen.add(docid)
        # The next rank is always one more than the rows emitted so far.
        ranked.append([qid, docid, len(ranked) + 1])
    return ranked
def test_dpr_encode_as_faiss(self):
    """Encode the corpus with DPR straight into a Faiss index and search it.

    Runs ``python -m pyserini.encode`` with ``--to-faiss`` into a temporary
    index directory, then opens that directory with ``SimpleDenseSearcher``
    and checks the top hit, the first component of its vector, and the
    number of indexed documents.
    """
    index_dir = f'{self.pyserini_root}/temp_index'
    # Registered before encoding so the directory is tracked even if a
    # later step fails (cleanup is presumably done in tearDown).
    self.temp_folders.append(index_dir)
    cmd1 = f"python -m pyserini.encode input   --corpus {self.corpus_path} \
                              --fields text \
                      output  --embeddings {index_dir} --to-faiss \
                      encoder --encoder facebook/dpr-ctx_encoder-multiset-base \
                              --fields text \
                              --batch 4 \
                              --device cpu"
    # Fail right here if encoding did not succeed, instead of surfacing an
    # unrelated error when the searcher tries to open a missing index.
    status = os.system(cmd1)
    self.assertEqual(
        status, 0, f'encoding command exited with status {status}')

    searcher = SimpleDenseSearcher(
        index_dir, 'facebook/dpr-question_encoder-multiset-base')
    # With return_vector=True the search result is unpacked into a pair;
    # only the hit list is inspected here.
    q_emb, hit = searcher.search(
        "What is the solution of separable closed queueing networks?",
        k=1, return_vector=True)
    self.assertEqual(hit[0].docid, 'CACM-2445')
    self.assertAlmostEqual(hit[0].vectors[0], -6.88267112e-01, places=4)
    self.assertEqual(searcher.num_docs, 3204)
def load_ranker(args):
    """Build the searcher selected by ``args.sparse`` / ``args.dense``.

    Args:
        args: Parsed options. Reads ``sparse`` and ``dense`` to choose the
            ranker; ``sparse_index_path``, ``k``, ``b``, ``expansion_terms``,
            ``expansion_documents`` and ``original_query_weight`` for the
            sparse searcher; ``dense_index_path`` for the dense searcher.

    Returns:
        A ``HybridSearcher`` when both flags are set, the sparse
        ``SimpleSearcher`` when only ``sparse`` is set, or the
        ``SimpleDenseSearcher`` when only ``dense`` is set.

    Raises:
        SystemExit: With code 0, after printing a usage hint, when neither
            flag is set.
    """
    use_sparse = args.sparse
    use_dense = args.dense

    if not (use_sparse or use_dense):
        print(
            "Choose a valid ranking function sparse(BM25), dense(vector) or a combination of the two"
        )
        exit(0)

    sparse_searcher = None
    if use_sparse:
        sparse_searcher = SimpleSearcher(args.sparse_index_path)
        sparse_searcher.set_bm25(args.k, args.b)
        sparse_searcher.set_rm3(args.expansion_terms,
                                args.expansion_documents,
                                args.original_query_weight)

    dense_searcher = None
    if use_dense:
        encoder = TCTColBERTQueryEncoder('castorini/tct_colbert-msmarco')
        dense_searcher = SimpleDenseSearcher(args.dense_index_path, encoder)

    if use_sparse and use_dense:
        # The hybrid searcher used to be built and then dropped, so this
        # branch returned None; it is now returned to the caller.
        return HybridSearcher(dense_searcher, sparse_searcher)
    return sparse_searcher if use_sparse else dense_searcher
def do_model(self, arg):
    """Switch the dense model used by the dense and hybrid searchers.

    Args:
        arg: Model key, either ``"tct"`` or ``"ance"``. Any other value
            prints an error message and leaves the searchers untouched.
    """
    if arg not in ("tct", "ance"):
        print(
            f'Model "{arg}" is invalid. Model should be one of [tct, ance].'
        )
        return

    if arg == "tct":
        query_encoder = TctColBertQueryEncoder("castorini/tct_colbert-msmarco")
        index_name = "msmarco-passage-tct_colbert-hnsw"
    else:
        query_encoder = AnceQueryEncoder("castorini/ance-msmarco-passage")
        index_name = "msmarco-passage-ance-bf"

    self.dsearcher = SimpleDenseSearcher.from_prebuilt_index(
        index_name, query_encoder)
    # The hybrid searcher wraps the dense one, so it has to be rebuilt too.
    self.hsearcher = HybridSearcher(self.dsearcher, self.ssearcher)
    print(f'setting model = {arg}')
# invalid topics name if topics == {}: print(f'Topic {args.run.topics} Not Found') exit() query_encoder = init_query_encoder(args.dense.encoder, args.run.topics, args.dense.encoded_queries, args.dense.device) if not query_encoder: print(f'No encoded queries for topic {args.run.topics}') exit() if os.path.exists(args.dense.index): # create searcher from index directory dsearcher = SimpleDenseSearcher(args.dense.index, query_encoder) else: # create searcher from prebuilt index name dsearcher = SimpleDenseSearcher.from_prebuilt_index(args.dense.index, query_encoder) if not dsearcher: exit() if os.path.exists(args.sparse.index): # create searcher from index directory ssearcher = SimpleSearcher(args.sparse.index) else: # create searcher from prebuilt index name ssearcher = SimpleSearcher.from_prebuilt_index(args.sparse.index) if not ssearcher:
topics = get_topics(args.topics) # invalid topics name if topics == {}: print(f'Topic {args.topics} Not Found') exit() query_encoder = init_query_encoder(args.encoder, args.topics, args.encoded_queries, args.device) if not query_encoder: print(f'No encoded queries for topic {args.topics}') exit() if os.path.exists(args.index): # create searcher from index directory searcher = SimpleDenseSearcher(args.index, query_encoder) else: # create searcher from prebuilt index name searcher = SimpleDenseSearcher.from_prebuilt_index( args.index, query_encoder) if not searcher: exit() # build output path output_path = args.output print(f'Running {args.topics} topics, saving to {output_path}...') tag = 'Faiss' order = None
parser.add_argument('--query', type=str, required=False, default='', help="user query appended to predictions") # index corpus, device parser.add_argument('--reader-model', type=str, required=False, help="Reader model name or path") parser.add_argument('--reader-device', type=str, required=False, default='cuda:0', help="Device to run inference on") args = parser.parse_args() # check arguments arg_check(args, parser) print("Init QA models") if args.type == 'openbook': if args.qa_reader == 'dpr': reader = DprReader(args.reader_model, device=args.reader_device) if args.retriever_model: retriever = SimpleDenseSearcher(args.retriever_index, DprQueryEncoder(args.retriever_model)) else: retriever = SimpleSearcher.from_prebuilt_index(args.retriever_corpus) corpus = SimpleSearcher.from_prebuilt_index(args.retriever_corpus) obqa = OpenBookQA(reader, retriever, corpus) # run a warm up question obqa.predict('what is lobster roll') while True: question = input('Enter a question: ') answer = obqa.predict(question) answer_text = answer["answer"] answer_context = answer["context"]["text"] print(f"Answer:\t {answer_text}") print(f"Context:\t {answer_context}") elif args.qa_reader == 'fid': reader = FidReader(model_name=args.reader_model, device=args.reader_device)
# create query encoder from query embedding directory query_encoder = TCTColBERTQueryEncoder(args.encoded_queries) else: # create query encoder from pre encoded query name query_encoder = TCTColBERTQueryEncoder.load_encoded_queries( args.encoded_queries) else: query_encoder = TCTColBERTQueryEncoder(encoder_dir=args.encoder, device=args.device) if not query_encoder: exit() if os.path.exists(args.index): # create searcher from index directory searcher = SimpleDenseSearcher(args.index, query_encoder) else: # create searcher from prebuilt index name searcher = SimpleDenseSearcher.from_prebuilt_index(args.index, query_encoder) if not searcher: exit() # invalid topics name if topics == {}: print(f'Topic {args.topics} Not Found') exit() # build output path output_path = args.output
class DPRDemo(cmd.Cmd):
    """Interactive shell for sparse, dense and hybrid retrieval.

    Any line that is not a recognised command is treated as a question and
    searched with the currently selected searcher. Commands may be written
    with a leading ``/`` (``/k 20``, ``/mode dense``, ``/random nq``).
    """

    # NOTE: all of the attributes below are evaluated once, when the class
    # body is executed, so topics, indexes and the encoder are loaded at
    # import time rather than when an instance is created.
    nq_dev_topics = list(search.get_topics('dpr-nq-dev').values())
    trivia_dev_topics = list(search.get_topics('dpr-trivia-dev').values())

    ssearcher = SimpleSearcher.from_prebuilt_index('wikipedia-dpr')
    # The sparse searcher is the default until /mode changes it.
    searcher = ssearcher

    encoder = DprQueryEncoder("facebook/dpr-question_encoder-multiset-base")
    index = 'wikipedia-dpr-multi-bf'
    dsearcher = SimpleDenseSearcher.from_prebuilt_index(
        index,
        encoder
    )
    hsearcher = HybridSearcher(dsearcher, ssearcher)

    k = 10
    prompt = '>>> '

    def precmd(self, line):
        """Strip one leading ``/`` so ``/cmd`` and ``cmd`` are equivalent."""
        # startswith is safe on an empty line (the user just pressing
        # Enter), where indexing line[0] raised IndexError.
        if line.startswith('/'):
            line = line[1:]
        return line

    def do_help(self, arg):
        """Print the list of supported commands."""
        print('/help    : returns this message')
        print('/k [NUM] : sets k (number of hits to return) to [NUM]')
        print('/mode [MODE] : sets retriver type to [MODE] (one of sparse, dense, hybrid)')
        print('/random [COLLECTION]: returns results for a random question from the dev subset [COLLECTION] (one of nq, trivia).')

    def do_k(self, arg):
        """Set the number of hits to return; ``arg`` must be a positive integer."""
        try:
            k = int(arg)
        except ValueError:
            # A typo must not take the whole interactive session down.
            print(f'k "{arg}" is invalid. k should be a positive integer.')
            return
        if k < 1:
            print(f'k "{arg}" is invalid. k should be a positive integer.')
            return
        print(f'setting k = {k}')
        self.k = k

    def do_mode(self, arg):
        """Select the active searcher: ``sparse``, ``dense`` or ``hybrid``."""
        if arg == "sparse":
            self.searcher = self.ssearcher
        elif arg == "dense":
            self.searcher = self.dsearcher
        elif arg == "hybrid":
            self.searcher = self.hsearcher
        else:
            print(
                f'Mode "{arg}" is invalid. Mode should be one of [sparse, dense, hybrid].')
            return
        print(f'setting retriver = {arg}')

    def do_random(self, arg):
        """Search a random dev question from the ``nq`` or ``trivia`` topics."""
        if arg == "nq":
            topics = self.nq_dev_topics
        elif arg == "trivia":
            topics = self.trivia_dev_topics
        else:
            print(
                f'Collection "{arg}" is invalid. Collection should be one of [nq, trivia].')
            return
        q = random.choice(topics)['title']
        print(f'question: {q}')
        self.default(q)

    def do_EOF(self, line):
        """Return True on end-of-input so the command loop stops."""
        return True

    def default(self, q):
        """Search ``q`` with the active searcher and print the top ``k`` hits."""
        hits = self.searcher.search(q, self.k)
        is_sparse = isinstance(self.searcher, SimpleSearcher)

        for rank, hit in enumerate(hits, start=1):
            if is_sparse:
                # Sparse hits carry the raw document themselves.
                raw_doc = hit.raw
            else:
                # Other searchers only give a docid; fetch the document.
                doc = self.searcher.doc(hit.docid)
                raw_doc = doc.raw() if doc else None

            if raw_doc is None:
                # json.loads(None) would raise TypeError and abort the
                # listing, so report the missing document and carry on.
                print(f'{rank:2} {hit.score:.5f} [document "{hit.docid}" not found]')
                continue

            jsondoc = json.loads(raw_doc)
            print(f'{rank:2} {hit.score:.5f} {jsondoc["contents"]}')
# Command-line interface: one sub-command for the reader options and one
# for the retriever options.
parser = argparse.ArgumentParser(description='Interactive QA')
commands = parser.add_subparsers(title='sub-commands')

# NOTE: despite its name, dense_parser holds the 'reader' arguments.
dense_parser = commands.add_parser('reader')
define_reader_args(dense_parser)

# NOTE: despite its name, sparse_parser holds the 'retriever' arguments.
sparse_parser = commands.add_parser('retriever')
define_retriever_args(sparse_parser)

# The result is read below as args.reader.* and args.retriever.*, i.e. one
# attribute group per sub-command.
args = parse_args(parser, commands)

print("Init QA models")
reader = DprReader(args.reader.model, device=args.reader.device)
# A retriever model selects dense retrieval over args.retriever.index;
# without one, the prebuilt index named by args.retriever.corpus is used.
if args.retriever.model:
    retriever = SimpleDenseSearcher(args.retriever.index, DprQueryEncoder(args.retriever.model))
else:
    retriever = SimpleSearcher.from_prebuilt_index(args.retriever.corpus)
# The corpus searcher is always opened from the prebuilt index, whichever
# retriever was chosen above.
corpus = SimpleSearcher.from_prebuilt_index(args.retriever.corpus)
obqa = OpenBookQA(reader, retriever, corpus)
# run a warm up question
obqa.predict('what is lobster roll')
# Interactive loop with no exit condition of its own: it runs until the
# process is interrupted or input() raises.
while True:
    question = input('Please enter a question: ')
    answer = obqa.predict(question)
    answer_text = answer["answer"]
    answer_context = answer["context"]["text"]
    print(f"ANSWER:\t {answer_text}")
    print(f"CONTEXT:\t {answer_context}")