Пример #1
0
    #         'No passages data found. Please specify ctx_file param properly.')

    # questions_doc_hits = validate(all_passages, question_answers, top_ids_and_scores, args.validation_workers,
    #                               args.match)

    if args.out_file:
        save_results(question_ids,
                     top_ids_and_scores, args.out_file)


if __name__ == '__main__':
    # Build the command-line interface of the script. Only the option
    # definitions live in this section; parsing happens further down.
    parser = argparse.ArgumentParser()

    # Option groups registered by project helpers (defined elsewhere).
    add_encoder_params(parser)
    add_tokenizer_params(parser)
    add_cuda_params(parser)

    # --- required inputs -------------------------------------------------
    parser.add_argument(
        '--qa_file',
        type=str,
        required=True,
        default=None,
        help="Question and answers file of the format: question \\t ['answer1','answer2', ...]",
    )
    parser.add_argument(
        '--ctx_file',
        type=str,
        required=True,
        default=None,
        help="All passages file in the tsv format: id \\t passage_text \\t title",
    )

    # --- optional inputs / outputs ---------------------------------------
    parser.add_argument(
        '--encoded_ctx_file',
        type=str,
        default=None,
        help='Glob path to encoded passages (from generate_dense_embeddings tool)',
    )
    parser.add_argument(
        '--out_file',
        type=str,
        default=None,
        help='output .json file path to write results to ',
    )

    # --- retrieval / validation behaviour --------------------------------
    parser.add_argument(
        '--match',
        type=str,
        default='string',
        choices=['regex', 'string'],
        help="Answer matching logic type",
    )
    # argparse stores this one as ``n_docs`` (dashes become underscores).
    parser.add_argument(
        '--n-docs',
        type=int,
        default=200,
        help="Amount of top docs to return",
    )
    parser.add_argument(
        '--validation_workers',
        type=int,
        default=16,
        help="Number of parallel processes to validate results",
    )
Пример #2
0
def setup_dpr(model_file,
              ctx_file,
              encoded_ctx_file,
              hnsw_index=False,
              save_or_load_index=False):
    """Initialise the module-level DPR retriever, passages and answer cache.

    The function loads the question encoder from ``model_file``, wraps it
    together with a dense index in a ``DenseRetriever``, fills the index
    (either from a previously serialized index or from the encoded passage
    files) and finally loads the passages from ``ctx_file``.

    Results are published through the module globals ``retriever``,
    ``all_passages``, ``answer_cache`` and ``answer_cache_path``.

    Args:
        model_file: Path of the bi-encoder checkpoint to load.
        ctx_file: Path of the passages file handed to ``load_passages``.
        encoded_ctx_file: Glob pattern matching the encoded passage files.
        hnsw_index: If true, use ``DenseHNSWFlatIndexer`` instead of
            ``DenseFlatIndexer``.
        save_or_load_index: If true, reuse a serialized index when one
            exists next to the encoded files, otherwise build the index and
            serialize it there.

    Raises:
        IndexError: If ``encoded_ctx_file`` does not match any file.
    """
    global retriever
    global all_passages
    global answer_cache
    global answer_cache_path

    # The cache file name is the SHA-1 of the three path arguments, so each
    # distinct model/passages/embeddings combination gets its own cache.
    # It is a bare file name, i.e. relative to the current working directory.
    parameter_setting = model_file + ctx_file + encoded_ctx_file
    answer_cache_path = hashlib.sha1(
        parameter_setting.encode("utf-8")).hexdigest()
    if os.path.exists(answer_cache_path):
        # NOTE(security): unpickling executes arbitrary code contained in the
        # file; this is only acceptable because the cache is a local file
        # expected to be written by this program. Do not point it at
        # untrusted data.
        with open(answer_cache_path, 'rb') as cache_file:
            answer_cache = pickle.load(cache_file)
    else:
        answer_cache = {}

    parser = argparse.ArgumentParser()
    add_encoder_params(parser)
    add_tokenizer_params(parser)
    add_cuda_params(parser)

    # NOTE(review): parse_args() without arguments reads sys.argv, so the
    # command line of the hosting process is interpreted here. Kept as-is
    # because callers may rely on passing encoder/CUDA flags that way.
    args = parser.parse_args()
    args.model_file = model_file
    args.ctx_file = ctx_file
    args.encoded_ctx_file = encoded_ctx_file
    args.hnsw_index = hnsw_index
    args.save_or_load_index = save_or_load_index
    args.batch_size = 1  # TODO

    setup_args_gpu(args)
    print_args(args)

    saved_state = load_states_from_checkpoint(args.model_file)
    set_encoder_params_from_state(saved_state.encoder_params, args)

    tensorizer, encoder, _ = init_biencoder_components(args.encoder_model_type,
                                                       args,
                                                       inference_only=True)

    # Only the question side of the bi-encoder is used from here on.
    encoder = encoder.question_model

    encoder, _ = setup_for_distributed_mode(encoder, None, args.device,
                                            args.n_gpu, args.local_rank,
                                            args.fp16)
    encoder.eval()

    # load weights from the model file
    model_to_load = get_model_obj(encoder)
    logger.info("Loading saved model state ...")

    # The checkpoint stores the whole bi-encoder; keep only the
    # "question_model." entries and strip that prefix from their keys.
    prefix_len = len("question_model.")
    question_encoder_state = {
        key[prefix_len:]: value
        for (key, value) in saved_state.model_dict.items()
        if key.startswith("question_model.")
    }
    model_to_load.load_state_dict(question_encoder_state)
    vector_size = model_to_load.get_out_size()
    logger.info("Encoder vector_size=%d", vector_size)

    if args.hnsw_index:
        index = DenseHNSWFlatIndexer(vector_size, 50000)
    else:
        index = DenseFlatIndexer(vector_size, 50000,
                                 "IVF65536,PQ64")  #IVF65536

    retriever = DenseRetriever(encoder, args.batch_size, tensorizer, index)

    # index all passages
    ctx_files_pattern = args.encoded_ctx_file
    input_paths = glob.glob(ctx_files_pattern)
    if not input_paths:
        # Same exception type as the former bare ``input_paths[0]`` failure,
        # but with a message that names the actual problem.
        raise IndexError(
            "No encoded passage files match pattern: {!r}".format(
                ctx_files_pattern))

    # The index location is the first matched path with its last
    # "_"-separated component removed.
    index_path = "_".join(input_paths[0].split("_")[:-1])
    if args.save_or_load_index and (os.path.exists(index_path) or
                                    os.path.exists(index_path + ".index.dpr")):
        retriever.index.deserialize_from(index_path)
    else:
        logger.info("Reading all passages data from files: %s", input_paths)
        retriever.index.index_data(input_paths)

        if args.save_or_load_index:
            retriever.index.serialize(index_path)

    # Load the passages themselves from ctx_file.
    all_passages = load_passages(args.ctx_file)