# 'No passages data found. Please specify ctx_file param properly.') # questions_doc_hits = validate(all_passages, question_answers, top_ids_and_scores, args.validation_workers, # args.match) if args.out_file: save_results(question_ids, top_ids_and_scores, args.out_file) if __name__ == '__main__': parser = argparse.ArgumentParser() add_encoder_params(parser) add_tokenizer_params(parser) add_cuda_params(parser) parser.add_argument('--qa_file', required=True, type=str, default=None, help="Question and answers file of the format: question \\t ['answer1','answer2', ...]") parser.add_argument('--ctx_file', required=True, type=str, default=None, help="All passages file in the tsv format: id \\t passage_text \\t title") parser.add_argument('--encoded_ctx_file', type=str, default=None, help='Glob path to encoded passages (from generate_dense_embeddings tool)') parser.add_argument('--out_file', type=str, default=None, help='output .json file path to write results to ') parser.add_argument('--match', type=str, default='string', choices=['regex', 'string'], help="Answer matching logic type") parser.add_argument('--n-docs', type=int, default=200, help="Amount of top docs to return") parser.add_argument('--validation_workers', type=int, default=16, help="Number of parallel processes to validate results")
def setup_dpr(model_file, ctx_file, encoded_ctx_file, hnsw_index=False, save_or_load_index=False):
    """Initialize the global DPR retrieval state.

    Loads the question-encoder weights from ``model_file``, builds a FAISS
    index over the pre-encoded passage embeddings matched by the
    ``encoded_ctx_file`` glob, and loads the raw passages from ``ctx_file``.

    Populates the module-level globals ``retriever``, ``all_passages``,
    ``answer_cache`` and ``answer_cache_path`` as a side effect.

    Args:
        model_file: Path to the bi-encoder checkpoint.
        ctx_file: TSV file with all passages (id \\t text \\t title).
        encoded_ctx_file: Glob pattern of encoded-passage files.
        hnsw_index: If True, use a DenseHNSWFlatIndexer instead of the
            IVF/PQ flat indexer.
        save_or_load_index: If True, reuse a serialized index when present
            and serialize a freshly built one.
    """
    global retriever
    global all_passages
    global answer_cache
    global answer_cache_path

    # The answer cache is keyed by the exact parameter combination, so a
    # different model/context setup never reads a stale cache file.
    parameter_setting = model_file + ctx_file + encoded_ctx_file
    answer_cache_path = hashlib.sha1(
        parameter_setting.encode("utf-8")).hexdigest()
    if os.path.exists(answer_cache_path):
        # Fix: open the cache inside a context manager so the file handle is
        # always closed (the original leaked it via pickle.load(open(...))).
        # NOTE(review): unpickling a file from disk executes arbitrary code if
        # the file is attacker-controlled — acceptable only for a trusted
        # local cache.
        with open(answer_cache_path, 'rb') as cache_file:
            answer_cache = pickle.load(cache_file)
    else:
        answer_cache = {}

    # Build an args namespace from the project's shared CLI definitions and
    # then override it with the explicit parameters of this call.
    parser = argparse.ArgumentParser()
    add_encoder_params(parser)
    add_tokenizer_params(parser)
    add_cuda_params(parser)
    args = parser.parse_args()
    args.model_file = model_file
    args.ctx_file = ctx_file
    args.encoded_ctx_file = encoded_ctx_file
    args.hnsw_index = hnsw_index
    args.save_or_load_index = save_or_load_index
    args.batch_size = 1  # TODO

    setup_args_gpu(args)
    print_args(args)

    saved_state = load_states_from_checkpoint(args.model_file)
    set_encoder_params_from_state(saved_state.encoder_params, args)

    tensorizer, encoder, _ = init_biencoder_components(
        args.encoder_model_type, args, inference_only=True)

    # Only the question side of the bi-encoder is needed for retrieval.
    encoder = encoder.question_model

    encoder, _ = setup_for_distributed_mode(encoder, None, args.device, args.n_gpu,
                                            args.local_rank, args.fp16)
    encoder.eval()

    # load weights from the model file
    model_to_load = get_model_obj(encoder)
    logger.info("Loading saved model state ...")

    # Strip the "question_model." prefix so the keys match the bare
    # question-encoder's state dict.
    prefix_len = len("question_model.")
    question_encoder_state = {
        key[prefix_len:]: value
        for (key, value) in saved_state.model_dict.items()
        if key.startswith("question_model.")
    }
    model_to_load.load_state_dict(question_encoder_state)
    vector_size = model_to_load.get_out_size()
    logger.info("Encoder vector_size=%d", vector_size)

    if args.hnsw_index:
        index = DenseHNSWFlatIndexer(vector_size, 50000)
    else:
        index = DenseFlatIndexer(vector_size, 50000, "IVF65536,PQ64")  # IVF65536

    retriever = DenseRetriever(encoder, args.batch_size, tensorizer, index)

    # index all passages
    ctx_files_pattern = args.encoded_ctx_file
    input_paths = glob.glob(ctx_files_pattern)
    # The serialized-index path is the shared prefix of the shard files
    # (everything before the final "_<shard-id>" suffix).
    index_path = "_".join(input_paths[0].split("_")[:-1])
    if args.save_or_load_index and (os.path.exists(index_path)
                                    or os.path.exists(index_path + ".index.dpr")):
        retriever.index.deserialize_from(index_path)
    else:
        logger.info("Reading all passages data from files: %s", input_paths)
        retriever.index.index_data(input_paths)
        if args.save_or_load_index:
            retriever.index.serialize(index_path)

    # get questions & answers
    all_passages = load_passages(args.ctx_file)