コード例 #1
0
def init_query_encoder(encoder, topics_name, encoded_queries, device):
    """Build a query encoder for dense retrieval.

    Resolution order:
      1. ``encoder``        -- instantiate a model-backed encoder, picking the
                               class from a substring of the name/path.
      2. ``encoded_queries`` -- load pre-computed query embeddings, either from
                               a local directory or a prebuilt set by name.
      3. ``topics_name``    -- fall back to a known prebuilt embedding set for
                               the topic collection, if one exists.

    Returns None when nothing matches.
    """
    # Topic name -> prebuilt encoded-query set shipped with the toolkit.
    prebuilt_queries = {
        'msmarco-passage-dev-subset': 'tct_colbert-msmarco-passage-dev-subset',
        'dpr-nq-dev': 'dpr_multi-nq-dev',
        'dpr-nq-test': 'dpr_multi-nq-test',
        'dpr-trivia-dev': 'dpr_multi-trivia-dev',
        'dpr-trivia-test': 'dpr_multi-trivia-test',
        'dpr-wq-test': 'dpr_multi-wq-test',
        'dpr-squad-test': 'dpr_multi-squad-test',
        'dpr-curated-test': 'dpr_multi-curated-test'
    }

    if encoder:
        # Dispatch on a substring of the model name/path; order matters, and
        # anything unrecognized falls back to AutoQueryEncoder.
        if 'dpr' in encoder:
            return DprQueryEncoder(encoder_dir=encoder, device=device)
        if 'tct_colbert' in encoder:
            return TctColBertQueryEncoder(encoder_dir=encoder, device=device)
        if 'ance' in encoder:
            return AnceQueryEncoder(encoder_dir=encoder, device=device)
        if 'sentence' in encoder:
            return AutoQueryEncoder(encoder_dir=encoder, device=device,
                                    pooling='mean', l2_norm=True)
        return AutoQueryEncoder(encoder_dir=encoder, device=device)

    if encoded_queries:
        # A local path is loaded directly; otherwise treat the value as the
        # name of a prebuilt encoded-query set.
        if os.path.exists(encoded_queries):
            return QueryEncoder(encoded_queries)
        return QueryEncoder.load_encoded_queries(encoded_queries)

    prebuilt_name = prebuilt_queries.get(topics_name)
    if prebuilt_name is not None:
        return QueryEncoder.load_encoded_queries(prebuilt_name)
    return None
コード例 #2
0
def init_query_encoder(encoder, tokenizer_name, topics_name, encoded_queries,
                       device, prefix):
    """Build a query encoder for dense retrieval.

    Resolution order:
      1. ``encoder``         -- instantiate a model-backed encoder, chosen by a
                                substring of the model name/path.
      2. ``encoded_queries`` -- load pre-computed query embeddings from a local
                                directory or a prebuilt set by name.
      3. ``topics_name``     -- fall back to a known prebuilt embedding set for
                                the topic collection.

    Raises ValueError when no source of query representations is available.
    """
    # Topic name -> prebuilt encoded-query set shipped with the toolkit.
    prebuilt_queries = {
        'msmarco-passage-dev-subset': 'tct_colbert-msmarco-passage-dev-subset',
        'dpr-nq-dev': 'dpr_multi-nq-dev',
        'dpr-nq-test': 'dpr_multi-nq-test',
        'dpr-trivia-dev': 'dpr_multi-trivia-dev',
        'dpr-trivia-test': 'dpr_multi-trivia-test',
        'dpr-wq-test': 'dpr_multi-wq-test',
        'dpr-squad-test': 'dpr_multi-squad-test',
        'dpr-curated-test': 'dpr_multi-curated-test'
    }

    if encoder:
        # 'dkrr' must be tested before 'dpr' (dkrr models take a prefix and
        # no tokenizer argument); everything else shares the same kwargs.
        if 'dkrr' in encoder:
            return DkrrDprQueryEncoder(encoder_dir=encoder,
                                       device=device,
                                       prefix=prefix)
        common = dict(encoder_dir=encoder,
                      tokenizer_name=tokenizer_name,
                      device=device)
        if 'dpr' in encoder:
            return DprQueryEncoder(**common)
        if 'bpr' in encoder:
            return BprQueryEncoder(**common)
        if 'tct_colbert' in encoder:
            return TctColBertQueryEncoder(**common)
        if 'ance' in encoder:
            return AnceQueryEncoder(**common)
        if 'sentence' in encoder:
            return AutoQueryEncoder(pooling='mean', l2_norm=True, **common)
        # Unrecognized names fall back to the generic HuggingFace encoder.
        return AutoQueryEncoder(**common)

    if encoded_queries:
        if not os.path.exists(encoded_queries):
            # Not a local path: treat it as a prebuilt encoded-query set name.
            return QueryEncoder.load_encoded_queries(encoded_queries)
        if 'bpr' in encoded_queries:
            return BprQueryEncoder(encoded_query_dir=encoded_queries)
        return QueryEncoder(encoded_queries)

    if topics_name in prebuilt_queries:
        return QueryEncoder.load_encoded_queries(prebuilt_queries[topics_name])
    raise ValueError(f'No encoded queries for topic {topics_name}')
コード例 #3
0
if __name__ == '__main__':
    # CLI: encode the queries of a topic set with a chosen encoder and store
    # the embeddings under --output.
    parser = argparse.ArgumentParser()
    parser.add_argument('--encoder', type=str, help='encoder name or path', required=True)
    parser.add_argument('--topics', type=str, help='topic name', required=True)
    parser.add_argument('--output', type=str, help='dir to store query embeddings', required=True)
    parser.add_argument('--device', type=str,
                        help='device cpu or cuda [cuda:0, cuda:1...]', default='cpu', required=False)
    args = parser.parse_args()
    device = args.device
    topics = get_topics(args.topics)

    # Create the output directory if it does not exist (single level only).
    if not os.path.exists(args.output):
        os.mkdir(args.output)

    # Pick the encoder class from a substring of the model name/path;
    # anything unrecognized falls back to AutoQueryEncoder.
    if 'dpr' in args.encoder:
        encoder = DprQueryEncoder(encoder_dir=args.encoder, device=device)
    elif 'tct_colbert' in args.encoder:
        encoder = TctColBertQueryEncoder(encoder_dir=args.encoder, device=device)
    elif 'ance' in args.encoder:
        encoder = AnceQueryEncoder(encoder_dir=args.encoder, device=device)
    elif 'sentence' in args.encoder:
        encoder = AutoQueryEncoder(encoder_dir=args.encoder, device=device, pooling='mean', l2_norm=True)
    else:
        encoder = AutoQueryEncoder(encoder_dir=args.encoder, device=device)

    # Collect query id/text (and, presumably, embeddings) per topic.
    embeddings = {'id': [], 'text': [], 'embedding': []}
    for key in tqdm(topics):
        qid = str(key)
        text = topics[key]['title']
        embeddings['id'].append(qid)
        embeddings['text'].append(text)
        # NOTE(review): this excerpt ends here — the 'embedding' list is never
        # filled and nothing is written to --output in the visible code;
        # confirm the encode/save steps against the full script.
コード例 #4
0
if __name__ == '__main__':
    # Compute DPR query embeddings for KILT topics and save them as a pickled
    # pandas DataFrame with 'text' and 'embedding' columns.
    parser = argparse.ArgumentParser(description='Compute embeddings for KILT topics')
    parser.add_argument('--topics', required=True)
    parser.add_argument('--output', default="embedding.pkl", help="Name and path to output file.")
    parser.add_argument('--encoder', metavar='path to query encoder checkpoint or encoder name',
                        required=True,
                        help="Path to query encoder pytorch checkpoint or hgf encoder model name")
    parser.add_argument('--tokenizer', metavar='name or path',
                        required=True,
                        help="Path to a hgf tokenizer name or path")
    parser.add_argument('--device', metavar='device to run query encoder', required=False, default='cpu',
                        help="Device to run query encoder, cpu or [cuda:0, cuda:1, ...]")
    args = parser.parse_args()

    query_iterator = get_query_iterator(args.topics, TopicsFormat.KILT)
    query_encoder = DprQueryEncoder(encoder_dir=args.encoder, tokenizer_name=args.tokenizer, device=args.device)

    texts = []
    embeddings = []
    # Topic ids from the iterator are not needed for the output frame, so
    # iterate directly instead of through enumerate.
    for _, text in tqdm(query_iterator):
        texts.append(text)
        embeddings.append(query_encoder.encode(text))

    df = pd.DataFrame({
        'text': texts,
        'embedding': embeddings
    })

    df.to_pickle(args.output)
コード例 #5
0
    parser.add_argument('--query', type=str, required=False, default='', help="user query appended to predictions")
    # index corpus, device
    parser.add_argument('--reader-model', type=str, required=False, help="Reader model name or path")
    parser.add_argument('--reader-device', type=str, required=False, default='cuda:0', help="Device to run inference on")

    args = parser.parse_args()

    # check arguments
    arg_check(args, parser)

    print("Init QA models")
    if args.type == 'openbook':
        if args.qa_reader == 'dpr':
            reader = DprReader(args.reader_model, device=args.reader_device)
            # Dense retrieval when a retriever model is given; otherwise fall
            # back to a prebuilt sparse index for retrieval.
            if args.retriever_model:
                retriever = SimpleDenseSearcher(args.retriever_index, DprQueryEncoder(args.retriever_model))
            else:
                retriever = SimpleSearcher.from_prebuilt_index(args.retriever_corpus)
            # The prebuilt sparse index doubles as the passage corpus.
            corpus = SimpleSearcher.from_prebuilt_index(args.retriever_corpus)
            obqa = OpenBookQA(reader, retriever, corpus)
            # run a warm up question
            obqa.predict('what is lobster roll')
            # Interactive loop: read a question, print answer and its context.
            while True:
                question = input('Enter a question: ')
                answer = obqa.predict(question)
                answer_text = answer["answer"]
                answer_context = answer["context"]["text"]
                print(f"Answer:\t {answer_text}")
                print(f"Context:\t {answer_context}")
        elif args.qa_reader == 'fid':
            # NOTE(review): excerpt is cut here — the FiD branch continues
            # beyond this view.
            reader = FidReader(model_name=args.reader_model, device=args.reader_device)
コード例 #6
0
class DPRDemo(cmd.Cmd):
    """Interactive shell for querying a Wikipedia DPR index with sparse,
    dense, or hybrid retrieval.

    Commands are prefixed with '/'; any other input is treated as a query
    and searched with the currently selected retriever.
    """

    nq_dev_topics = list(search.get_topics('dpr-nq-dev').values())
    trivia_dev_topics = list(search.get_topics('dpr-trivia-dev').values())

    ssearcher = SimpleSearcher.from_prebuilt_index('wikipedia-dpr')
    searcher = ssearcher  # currently active retriever; switched via /mode

    encoder = DprQueryEncoder("facebook/dpr-question_encoder-multiset-base")
    index = 'wikipedia-dpr-multi-bf'
    dsearcher = SimpleDenseSearcher.from_prebuilt_index(
        index,
        encoder
    )
    hsearcher = HybridSearcher(dsearcher, ssearcher)

    k = 10  # number of hits to return per query
    prompt = '>>> '

    def precmd(self, line):
        """Strip the leading '/' so '/k 5' dispatches to do_k, etc."""
        # startswith() avoids an IndexError when the user enters a blank line.
        if line.startswith('/'):
            line = line[1:]
        return line

    def do_help(self, arg):
        """Print the list of supported commands."""
        print(f'/help    : returns this message')
        print(f'/k [NUM] : sets k (number of hits to return) to [NUM]')
        print(f'/mode [MODE] : sets retriver type to [MODE] (one of sparse, dense, hybrid)')
        print(f'/random [COLLECTION]: returns results for a random question from the dev subset [COLLECTION] (one of nq, trivia).')

    def do_k(self, arg):
        """Set the number of hits returned per query."""
        print(f'setting k = {int(arg)}')
        self.k = int(arg)

    def do_mode(self, arg):
        """Switch the active retriever to sparse, dense, or hybrid."""
        if arg == "sparse":
            self.searcher = self.ssearcher
        elif arg == "dense":
            self.searcher = self.dsearcher
        elif arg == "hybrid":
            self.searcher = self.hsearcher
        else:
            print(
                f'Mode "{arg}" is invalid. Mode should be one of [sparse, dense, hybrid].')
            return
        print(f'setting retriver = {arg}')

    def do_random(self, arg):
        """Run a random dev-set question from the given collection."""
        if arg == "nq":
            topics = self.nq_dev_topics
        elif arg == "trivia":
            topics = self.trivia_dev_topics
        else:
            print(
                f'Collection "{arg}" is invalid. Collection should be one of [nq, trivia].')
            return
        q = random.choice(topics)['title']
        print(f'question: {q}')
        self.default(q)

    def do_EOF(self, line):
        """Exit the shell on end-of-file (Ctrl-D)."""
        return True

    def default(self, q):
        """Treat unrecognized input as a query and print the top-k hits."""
        hits = self.searcher.search(q, self.k)

        for i in range(0, len(hits)):
            raw_doc = None
            # Sparse hits carry the raw document; dense/hybrid hits require a
            # separate lookup by docid.
            if isinstance(self.searcher, SimpleSearcher):
                raw_doc = hits[i].raw
            else:
                doc = self.searcher.doc(hits[i].docid)
                if doc:
                    raw_doc = doc.raw()
            jsondoc = json.loads(raw_doc)
            print(f'{i + 1:2} {hits[i].score:.5f} {jsondoc["contents"]}')
コード例 #7
0
    commands = parser.add_subparsers(title='sub-commands')

    # Sub-commands carry the reader and retriever argument groups.
    dense_parser = commands.add_parser('reader')
    define_reader_args(dense_parser)

    sparse_parser = commands.add_parser('retriever')
    define_retriever_args(sparse_parser)

    args = parse_args(parser, commands)

    print("Init QA models")
    reader = DprReader(args.reader.model, device=args.reader.device)
    # Dense retrieval when a retriever model is supplied; otherwise fall back
    # to the prebuilt sparse index for retrieval as well.
    if args.retriever.model:
        retriever = SimpleDenseSearcher(args.retriever.index,
                                        DprQueryEncoder(args.retriever.model))
    else:
        retriever = SimpleSearcher.from_prebuilt_index(args.retriever.corpus)
    # The prebuilt sparse index doubles as the passage corpus for the reader.
    corpus = SimpleSearcher.from_prebuilt_index(args.retriever.corpus)
    obqa = OpenBookQA(reader, retriever, corpus)

    # run a warm up question
    obqa.predict('what is lobster roll')
    # Interactive loop: read a question, print the answer and its context.
    while True:
        question = input('Please enter a question: ')
        answer = obqa.predict(question)
        answer_text = answer["answer"]
        answer_context = answer["context"]["text"]
        print(f"ANSWER:\t {answer_text}")
        print(f"CONTEXT:\t {answer_context}")