Example 1
def main(args):
    questions = []
    question_answers = []

    # read (question, answers) pairs from the QA file
    for ds_item in parse_qa_csv_file(args.qa_file):
        question, answers = ds_item
        questions.append(question)
        question_answers.append(answers)

    # retrieve the ids and scores of the top n_docs passages for every question
    # (the `ranker` object is assumed to be initialized outside this snippet)
    top_ids_and_scores = []
    for question in questions:
        psg_ids, scores = ranker.closest_docs(question, args.n_docs)
        top_ids_and_scores.append((psg_ids, scores))

    all_passages = load_passages(args.db_path)

    if len(all_passages) == 0:
        raise RuntimeError(
            'No passages data found. Please specify ctx_file param properly.')

    # check which of the retrieved passages actually contain an answer string
    questions_doc_hits = validate(all_passages, question_answers,
                                  top_ids_and_scores, args.validation_workers,
                                  args.match)

    if args.out_file:
        save_results(all_passages, questions, question_answers,
                     top_ids_and_scores, questions_doc_hits, args.out_file)
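
parse_qa_csv_file and load_passages come from DPR's dense_retriever module and are not shown above. A minimal sketch of the QA-file reader, assuming the usual tab-separated layout with the question in the first column and a Python-literal list of answer strings in the second, could look like this:

import ast
import csv

def parse_qa_csv_file(location):
    # yield (question, [answer, ...]) tuples from a tab-separated file
    with open(location) as f:
        reader = csv.reader(f, delimiter='\t')
        for row in reader:
            question = row[0]
            answers = ast.literal_eval(row[1])
            yield question, answers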
Example 2
    def __init__(self, name, **config):
        super().__init__(name)

        self.args = argparse.Namespace(**config)
        saved_state = load_states_from_checkpoint(self.args.model_file)
        set_encoder_params_from_state(saved_state.encoder_params, self.args)
        tensorizer, encoder, _ = init_biencoder_components(
            self.args.encoder_model_type, self.args, inference_only=True)
        encoder = encoder.question_model
        encoder, _ = setup_for_distributed_mode(
            encoder,
            None,
            self.args.device,
            self.args.n_gpu,
            self.args.local_rank,
            self.args.fp16,
        )
        encoder.eval()

        # load weights from the model file
        model_to_load = get_model_obj(encoder)

        prefix_len = len("question_model.")
        question_encoder_state = {
            key[prefix_len:]: value
            for (key, value) in saved_state.model_dict.items()
            if key.startswith("question_model.")
        }
        model_to_load.load_state_dict(question_encoder_state)
        vector_size = model_to_load.get_out_size()

        # either load a pre-built HNSW index from disk, or create a flat FAISS index
        # that is populated from the encoded passage files further below
        index_buffer_sz = self.args.index_buffer
        if self.args.hnsw_index:
            index = DenseHNSWFlatIndexer(vector_size)
            index.deserialize_from(self.args.hnsw_index_path)
        else:
            index = DenseFlatIndexer(vector_size)

        self.retriever = DenseRetriever(encoder, self.args.batch_size,
                                        tensorizer, index)

        # index all passages
        ctx_files_pattern = self.args.encoded_ctx_file
        input_paths = glob.glob(ctx_files_pattern)

        if not self.args.hnsw_index:
            self.retriever.index_encoded_data(input_paths,
                                              buffer_size=index_buffer_sz)

        # not needed for now
        self.all_passages = load_passages(self.args.ctx_file)

        self.KILT_mapping = None
        if self.args.KILT_mapping:
            self.KILT_mapping = pickle.load(open(self.args.KILT_mapping, "rb"))
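
Once the constructor has run, the object holds a DenseRetriever over the FAISS index plus the passage texts, and querying it follows the same pattern as Example 3. The method below is only an illustrative sketch (its name and the top_k parameter are not part of the original class):

    def retrieve(self, questions, top_k=100):
        # encode the questions and search the index for the closest passages
        questions_tensor = self.retriever.generate_question_vectors(questions)
        top_ids_and_scores = self.retriever.get_top_docs(questions_tensor.numpy(), top_k)

        # map internal passage ids back to (text, title) pairs
        return [
            [self.all_passages[doc_id] for doc_id in doc_ids]
            for doc_ids, scores in top_ids_and_scores
        ]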
Example 3
        index_buffer_sz = args.index_buffer
        if args.hnsw_index:
            index = DenseHNSWFlatIndexer(vector_size)
            index_buffer_sz = -1  # encode all at once
        else:
            index = DenseFlatIndexer(vector_size)

        retriever = DenseRetriever(encoder, args.batch_size, tensorizer, index)
        retriever.index.deserialize_from(args.dense_index_path)

        questions_tensor = retriever.generate_question_vectors(questions)
        top_ids_and_scores = retriever.get_top_docs(questions_tensor.numpy(), args.n_docs)


    all_passages = load_passages(args.db_path)

    retrieval_file = "tmp_{}.json".format(str(np.random.randint(0, 100000)).zfill(6))
    questions_doc_hits = validate(all_passages, question_answers, top_ids_and_scores,
                                  1, args.match)

    save_results(all_passages,
                 questions,
                 question_answers, #["" for _ in questions],
                 top_ids_and_scores,
                 questions_doc_hits, #[[False for _ in range(args.n_docs)] for _n in questions],
                 retrieval_file)
    setup_args_gpu(args)
    #print_args(args)
    args.dev_file = retrieval_file
#!/usr/bin/env python3

# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.

import sys
import numpy as np
import json

sys.path.append("DPR")

from dense_retriever import iterate_encoded_files, load_passages

docs = load_passages(sys.argv[1])
vector_files = sys.argv[2:]
phi = 275.26935  # pre-computed phi for the transformation from inner-product space to Euclidean space

# Write out an empty query document for the question encoder
doc = {"put": "id:query:query::1", "fields": {}}
json.dump(doc, sys.stdout)
sys.stdout.write('\n')

# Write all Wikipedia articles
for i, item in enumerate(iterate_encoded_files(vector_files)):
    db_id, doc_vector = item
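    # MIPS-to-L2 reduction: appending sqrt(phi - ||v||^2) as an extra dimension to every
    # document vector (with phi >= max ||v||^2, and queries padded with a 0) makes
    # Euclidean nearest-neighbour search equivalent to maximum-inner-product search,
    # since ||q' - v'||^2 = ||q||^2 + phi - 2 * <q, v>.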
    norm = (doc_vector**2).sum()
    aux_dim = np.sqrt(phi - norm)
    l2_vector = np.hstack((doc_vector, aux_dim))
    passage_text, title = docs[db_id]
    doc = {
        "put": "id:wiki:wiki::%s" % db_id,
        "fields": {