Example No. 1
def retrieve_documents_for_generated_questions(config, input_file, verbose):
    # Load the retriever named in the config and the generated questions to filter
    retriever = load_retriever(config["retriever"])
    generated_questions = load_jsonl(input_file)
    logger.info("Running Filterer Retriever...")
    # Attach retrieved supporting documents to each generated question
    generated_questions_with_retrieved_docs = retriever.retrieve_documents(
        generated_questions)
    return generated_questions_with_retrieved_docs
Example No. 2
def generate_questions(config, input_file, verbose):
    # Build the question generator from the config and load the (passage, answer) pairs
    question_generator = load_question_generator(config)
    passage_answer_pairs = load_jsonl(input_file)
    logger.info("Running Question Generation...")
    # Generate questions for each (passage, answer) pair;
    # the progress bar is suppressed unless verbose is set
    annotations = question_generator.generate_questions_from_passage_answer_pairs(
        passage_answer_pairs, disable_tqdm=not verbose)
    return annotations
Example No. 3
def generate_answers_for_generated_questions_with_retrieved_docs(
        config, input_file, verbose):
    # Load the reader named in the config and the questions with their retrieved documents
    reader = load_reader(config["reader"])
    generated_questions_with_retrieved_docs = load_jsonl(input_file)

    logger.info("Running Filterer Reader...")
    # Re-answer each generated question from its retrieved documents
    results = reader.generate_answers(generated_questions_with_retrieved_docs)
    return results
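
Examples 1 to 3 read and write jsonl records and look like steps of one question-generation pipeline. A minimal, hedged sketch of how they might be chained; the intermediate file names are hypothetical and dump_jsonl is the helper used in Example No. 4:

def run_generation_pipeline(config, passages_file, verbose=False):
    # Hedged sketch: chain the three steps above; intermediate paths are made up.
    questions = generate_questions(config, passages_file, verbose)
    dump_jsonl(questions, "generated_questions.jsonl")

    with_docs = retrieve_documents_for_generated_questions(
        config, "generated_questions.jsonl", verbose)
    dump_jsonl(with_docs, "generated_questions_with_docs.jsonl")

    answers = generate_answers_for_generated_questions_with_retrieved_docs(
        config, "generated_questions_with_docs.jsonl", verbose)
    dump_jsonl(answers, "generated_answers.jsonl")
    return answers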
Example No. 4
def run_predictions(qas_to_rerank_file, output_file, model_name_or_path,
                    batch_size, fp16, top_k):
    # Load the QA pairs to rerank and the reranker model plus its tokenizer
    qas_to_rerank = load_jsonl(qas_to_rerank_file)
    reranker_model, reranker_tokenizer = load_reranker(model_name_or_path)

    # Rerank each input, keep the top_k candidates, and write the results as jsonl
    predictions = predict(reranker_model,
                          reranker_tokenizer,
                          qas_to_rerank,
                          bsz=batch_size,
                          fp16=fp16,
                          top_k=top_k)
    dump_jsonl(predictions, output_file)
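
A hedged sketch of how run_predictions might be wired to a command line, mirroring the argparse pattern in Example No. 8; the flag names and defaults are assumptions, not PAQ's actual CLI:

import argparse

# Hedged sketch: a possible command-line wrapper for run_predictions.
# Flag names and defaults here are assumptions.
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--qas_to_rerank_file', type=str, required=True)
    parser.add_argument('--output_file', type=str, required=True)
    parser.add_argument('--model_name_or_path', type=str, required=True)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--fp16', action='store_true')
    parser.add_argument('--top_k', type=int, default=50)
    args = parser.parse_args()

    run_predictions(args.qas_to_rerank_file, args.output_file,
                    args.model_name_or_path, args.batch_size,
                    args.fp16, args.top_k)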
Example No. 5
def _add_passage_metadata(questions_fi, passage_scores):
    generated_qas = load_jsonl(questions_fi)
    qas_dict = defaultdict(list)
    for qas in generated_qas:
        question, answer, passage_id = qas["question"], qas["answer"], qas["passage_id"]
        # Attach the source passage id and its score to each QA pair
        metadata = {
            "passage_id": passage_id,
            "ps_score": passage_scores[passage_id],
            "answer": answer,
        }
        metadata.update(qas["metadata"])
        # Group (answer, metadata) tuples by question text
        qas_dict[question].append((answer, metadata))
    return qas_dict
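
A self-contained, hedged illustration of the grouping idea above, run on in-memory records instead of a jsonl file; the field values are made up:

from collections import defaultdict

# Two hypothetical generated QA records for the same question, from different passages.
generated_qas = [
    {"question": "who wrote hamlet", "answer": "shakespeare",
     "passage_id": "p1", "metadata": {"source": "wiki"}},
    {"question": "who wrote hamlet", "answer": "william shakespeare",
     "passage_id": "p7", "metadata": {"source": "wiki"}},
]
passage_scores = {"p1": 0.91, "p7": 0.42}

qas_dict = defaultdict(list)
for qas in generated_qas:
    metadata = {"passage_id": qas["passage_id"],
                "ps_score": passage_scores[qas["passage_id"]],
                "answer": qas["answer"]}
    metadata.update(qas["metadata"])
    qas_dict[qas["question"]].append((qas["answer"], metadata))

# qas_dict["who wrote hamlet"] now holds two (answer, metadata) tuples,
# one per source passage.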
Example No. 6
def embed_job(qas_to_embed_path, model_name_or_path, output_file_name, n_jobs,
              job_num, batch_size, fp16, memory_friendly_parsing):
    os.makedirs(os.path.dirname(output_file_name), exist_ok=True)

    qas_to_embed = load_jsonl(qas_to_embed_path,
                              memory_friendly=memory_friendly_parsing)
    # Split the inputs into n_jobs contiguous chunks and keep only this job's chunk
    chunk_size = math.ceil(len(qas_to_embed) / n_jobs)
    qas_to_embed_this_job = qas_to_embed[job_num * chunk_size:(job_num + 1) * chunk_size]
    logger.info(
        f'Embedding Job {job_num}: Embedding {len(qas_to_embed_this_job)} inputs '
        f'in {math.ceil(len(qas_to_embed_this_job) / batch_size)} batches:'
    )

    model, tokenizer = load_retriever(model_name_or_path)
    mat = embed(model,
                tokenizer,
                qas_to_embed_this_job,
                bsz=batch_size,
                fp16=fp16)
    # Save this job's embeddings as a half-precision shard, suffixed with the job number
    torch.save(mat.half(), output_file_name + f'.{job_num}')
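
embed_job only writes the shard for its own chunk. A hedged sketch of how the per-job shards might be merged afterwards; the shard naming follows the f'.{job_num}' suffix above, and concatenating along dim 0 (one row per input) is an assumption:

import torch

# Hedged sketch: concatenate the per-job embedding shards written by embed_job.
# Assumes shards output_file_name.0 ... output_file_name.{n_jobs-1} all exist.
def merge_embedding_shards(output_file_name, n_jobs):
    shards = [torch.load(f'{output_file_name}.{job_num}') for job_num in range(n_jobs)]
    return torch.cat(shards, dim=0)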
Example No. 7
def load_passages(path):
    # Try PAQ's jsonl passage format first; fall back to DPR's tsv passage format
    try:
        return load_jsonl(path)
    except Exception:
        return load_dpr_tsv(path)
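
All of these snippets rely on load_jsonl / dump_jsonl helpers that are not shown here. A minimal sketch of what they are assumed to do, i.e. one JSON object per line of the file:

import json

# Minimal sketch of the jsonl helpers assumed by these examples.
def load_jsonl(path):
    with open(path) as f:
        return [json.loads(line) for line in f if line.strip()]

def dump_jsonl(items, path):
    with open(path, 'w') as f:
        for item in items:
            f.write(json.dumps(item) + '\n')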
Example No. 8
        # Print accuracy for each k as a percentage, plus the raw hit counts
        if not dont_print:
            print(
                f'{k}: {100 * sum(scores) / len(scores):0.1f}% \n({sum(scores)} / {len(scores)})'
            )


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--predictions',
        type=str,
        help="path to retrieval results to eval, in PAQ's retrieved results jsonl format"
    )
    parser.add_argument('--references',
                        type=str,
                        help="path to gold answers, in jsonl format")
    parser.add_argument('--hits_at_k',
                        type=str,
                        help='comma separated list of K to eval hits@k for',
                        default="1,10,50")
    args = parser.parse_args()

    refs = load_jsonl(args.references)
    preds = load_jsonl(args.predictions)
    assert len(refs) == len(preds), \
        "number of references doesn't match number of predictions"

    hits_at_k = sorted([int(k) for k in args.hits_at_k.split(',')])
    eval_retriever(refs, preds, hits_at_k)
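
eval_retriever itself is not shown above. A hedged sketch of a hits@k computation consistent with the printed output in this example; the record fields ("answer" in the references, "retrieved" with per-hit "answer" strings in the predictions) are assumptions, not PAQ's documented format:

# Hedged sketch of a hits@k metric matching the percentage / raw-count printout above.
def hits_at_k_score(refs, preds, k):
    scores = []
    for ref, pred in zip(refs, preds):
        gold_answers = {a.lower() for a in ref["answer"]}
        top_k_answers = {r["answer"].lower() for r in pred["retrieved"][:k]}
        # Score 1 if any of the top-k retrieved answers exactly matches a gold answer
        scores.append(int(bool(gold_answers & top_k_answers)))
    return 100 * sum(scores) / len(scores)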