Example #1
def get_retriever(args):
    saved_state = load_states_from_checkpoint(args.model_file)
    set_encoder_params_from_state(saved_state.encoder_params, args)

    tensorizer, encoder, _ = init_biencoder_components(args.encoder_model_type,
                                                       args,
                                                       inference_only=True)

    encoder = encoder.question_model
    encoder, _ = setup_for_distributed_mode(encoder, None, args.device,
                                            args.n_gpu, args.local_rank,
                                            args.fp16)
    encoder.eval()

    # load weights from the model file
    model_to_load = get_model_obj(encoder)
    logger.info('Loading saved model state ...')

    prefix_len = len('question_model.')
    question_encoder_state = {
        key[prefix_len:]: value
        for (key, value) in saved_state.model_dict.items()
        if key.startswith('question_model.')
    }
    model_to_load.load_state_dict(question_encoder_state)
    vector_size = model_to_load.get_out_size()
    logger.info('Encoder vector_size=%d', vector_size)

    index_buffer_sz = args.index_buffer
    if args.hnsw_index:
        index = DenseHNSWFlatIndexer(vector_size)
        index_buffer_sz = -1  # encode all at once
    else:
        index = DenseFlatIndexer(vector_size)

    retriever = DenseRetriever(encoder, args.batch_size, tensorizer, index,
                               args.device)

    # index all passages
    if len(args.encoded_ctx_file) > 0:
        ctx_files_pattern = args.encoded_ctx_file
        input_paths = glob.glob(ctx_files_pattern)

        index_path = "_".join(input_paths[0].split("_")[:-1])
        if args.save_or_load_index and os.path.exists(index_path):
            retriever.index.deserialize(index_path)
        else:
            logger.info('Reading all passages data from files: %s',
                        input_paths)
            retriever.index_encoded_data(input_paths,
                                         buffer_size=index_buffer_sz)
            if args.save_or_load_index:
                retriever.index.serialize(index_path)
        # get questions & answers
    return retriever
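These examples assume DPR's helper modules are importable. A minimal driver sketch for get_retriever() follows; the module paths mirror facebookresearch/DPR, and every path and hyperparameter value below is a placeholder, not taken from the example above.

# Imports these snippets rely on (module layout as in facebookresearch/DPR).
import argparse
import glob
import logging
import os

from dpr.indexer.faiss_indexers import DenseFlatIndexer, DenseHNSWFlatIndexer
from dpr.models import init_biencoder_components
from dpr.options import set_encoder_params_from_state
from dpr.utils.model_utils import (get_model_obj, load_states_from_checkpoint,
                                   setup_for_distributed_mode)

logger = logging.getLogger(__name__)

# Hypothetical args for get_retriever(); all values are placeholders.
args = argparse.Namespace(
    model_file='checkpoints/dpr_biencoder.cp',      # trained bi-encoder checkpoint
    encoder_model_type='hf_bert',
    encoded_ctx_file='embeddings/wiki_passages_*',  # glob of encoded passage shards
    hnsw_index=False,
    save_or_load_index=False,
    index_buffer=50000,
    batch_size=32,
    device='cuda',
    n_gpu=1,
    local_rank=-1,
    fp16=False,
)

retriever = get_retriever(args)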
Example #2
    def __init__(self, name, **config):
        super().__init__(name)

        self.args = argparse.Namespace(**config)
        saved_state = load_states_from_checkpoint(self.args.model_file)
        set_encoder_params_from_state(saved_state.encoder_params, self.args)
        tensorizer, encoder, _ = init_biencoder_components(
            self.args.encoder_model_type, self.args, inference_only=True)
        encoder = encoder.question_model
        encoder, _ = setup_for_distributed_mode(
            encoder,
            None,
            self.args.device,
            self.args.n_gpu,
            self.args.local_rank,
            self.args.fp16,
        )
        encoder.eval()

        # load weights from the model file
        model_to_load = get_model_obj(encoder)

        prefix_len = len("question_model.")
        question_encoder_state = {
            key[prefix_len:]: value
            for (key, value) in saved_state.model_dict.items()
            if key.startswith("question_model.")
        }
        model_to_load.load_state_dict(question_encoder_state)
        vector_size = model_to_load.get_out_size()

        index_buffer_sz = self.args.index_buffer
        if self.args.hnsw_index:
            index = DenseHNSWFlatIndexer(vector_size)
            index.deserialize_from(self.args.hnsw_index_path)
        else:
            index = DenseFlatIndexer(vector_size)

        self.retriever = DenseRetriever(encoder, self.args.batch_size,
                                        tensorizer, index)

        # index all passages
        ctx_files_pattern = self.args.encoded_ctx_file
        input_paths = glob.glob(ctx_files_pattern)

        if not self.args.hnsw_index:
            self.retriever.index_encoded_data(input_paths,
                                              buffer_size=index_buffer_sz)

        # not needed for now
        self.all_passages = load_passages(self.args.ctx_file)

        self.KILT_mapping = None
        if self.args.KILT_mapping:
            self.KILT_mapping = pickle.load(open(self.args.KILT_mapping, "rb"))
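The class that owns this __init__ is not shown in the example. A hedged instantiation sketch follows, with DPRRetrieverComponent as a placeholder name and only the config keys the constructor actually reads; all values are assumptions.

# Placeholder class name and config values; the keyword arguments mirror the
# self.args fields read by the constructor above.
component = DPRRetrieverComponent(
    'dpr',
    model_file='checkpoints/dpr_biencoder.cp',
    encoder_model_type='hf_bert',
    encoded_ctx_file='embeddings/kilt_passages_*',
    ctx_file='data/kilt_w100_title.tsv',
    hnsw_index=False,
    hnsw_index_path=None,
    index_buffer=50000,
    batch_size=32,
    device='cuda',
    n_gpu=1,
    local_rank=-1,
    fp16=False,
    KILT_mapping=None,
)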
Example #3
def main(args):
    saved_state = load_states_from_checkpoint(args.model_file)
    set_encoder_params_from_state(saved_state.encoder_params, args)

    tensorizer, encoder, _ = init_biencoder_components(
        args.encoder_model_type, args, inference_only=True)

    encoder = encoder.question_model

    encoder, _ = setup_for_distributed_mode(encoder, None, args.device, args.n_gpu,
                                            args.local_rank,
                                            args.fp16)
    encoder.eval()

    # load weights from the model file
    model_to_load = get_model_obj(encoder)
    logger.info('Loading saved model state ...')

    prefix_len = len('question_model.')
    question_encoder_state = {key[prefix_len:]: value for (key, value) in saved_state.model_dict.items() if
                              key.startswith('question_model.')}
    model_to_load.load_state_dict(question_encoder_state)
    vector_size = model_to_load.get_out_size()
    logger.info('Encoder vector_size=%d', vector_size)

    if args.hnsw_index:
        index = DenseHNSWFlatIndexer(vector_size, args.index_buffer)
    else:
        index = DenseFlatIndexer(vector_size, args.index_buffer)

    retriever = DenseRetriever(encoder, args.batch_size, tensorizer, index)

    # index all passages
    ctx_files_pattern = args.encoded_ctx_file
    input_paths = glob.glob(ctx_files_pattern)

    index_path = "_".join(input_paths[0].split("_")[:-1])
    if args.save_or_load_index and (os.path.exists(index_path) or os.path.exists(index_path + ".index.dpr")):
        retriever.index.deserialize_from(index_path)
    else:
        logger.info('Reading all passages data from files: %s', input_paths)
        retriever.index.index_data(input_paths)
        if args.save_or_load_index:
            retriever.index.serialize(index_path)
    # get questions & answers
    questions = []
    question_ids = []

    for ds_item in parse_qa_csv_file(args.qa_file):
        question_id, question = ds_item
        question_ids.append(question_id)
        questions.append(question)

    questions_tensor = retriever.generate_question_vectors(questions)

    # get top k results
    top_ids_and_scores = retriever.get_top_docs(
        questions_tensor.numpy(), args.n_docs)

    # all_passages = load_passages(args.ctx_file)

    # if len(all_passages) == 0:
    #     raise RuntimeError(
    #         'No passages data found. Please specify ctx_file param properly.')

    # questions_doc_hits = validate(all_passages, question_answers, top_ids_and_scores, args.validation_workers,
    #                               args.match)

    if args.out_file:
        save_results(question_ids,
                     top_ids_and_scores, args.out_file)
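get_top_docs returns one (passage ids, scores) pair per question. save_results is not shown here; a hedged sketch of flattening the same inputs into a TSV, with write_retrieval_tsv being a hypothetical helper rather than part of DPR:

# Hypothetical helper: dump retrieval results as TSV rows (question_id, passage_id, score).
# Assumes top_ids_and_scores is a list of (doc_ids, scores) pairs, one per question.
import csv

def write_retrieval_tsv(question_ids, top_ids_and_scores, out_path):
    with open(out_path, 'w', newline='') as f:
        writer = csv.writer(f, delimiter='\t')
        for qid, (doc_ids, scores) in zip(question_ids, top_ids_and_scores):
            for doc_id, score in zip(doc_ids, scores):
                writer.writerow([qid, doc_id, float(score)])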
Example #4
    print('Creating docid file...')
    not_found = set()

    with open(os.path.join(args.input_dir, 'kilt_w100_title.tsv'), 'r') as f, \
            open(os.path.join(args.input_dir, 'docid'), 'w') as outp:
        tsv = csv.reader(f, delimiter='\t')
        next(tsv)  # skip headers
        for row in tqdm(tsv, mininterval=10.0, maxinterval=20.0):
            i = row[0]
            title = row[2]
            if title not in KILT_mapping:
                not_found.add(f"{title}#{i}")
                wikipedia_id = 'N/A'
            else:
                wikipedia_id = KILT_mapping[title]
            docid = f"{wikipedia_id}#{i}" if args.passage else wikipedia_id
            _ = outp.write(f'{docid}\n')

    print("Done writing docid file!")
    print(f'Some documents did not have a docid in the mapping: {not_found}')

    print('Creating index file...')
    ctx_files_pattern = f'{args.input_dir}/kilt_passages_2048_0.pkl'
    input_paths = glob.glob(ctx_files_pattern)

    vector_size = 768
    index = DenseFlatIndexer(vector_size)
    index.index_data(input_paths)
    # write the underlying faiss index wrapped by DenseFlatIndexer
    faiss.write_index(index.index, f'{args.output_dir}/index')
    print('Done writing index file!')
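A hedged sketch of reading that index back and querying it; faiss.read_index and Index.search are standard faiss calls, and the query vectors here are random placeholders.

# Load the serialized faiss index and run a nearest-neighbour query (placeholder data).
import faiss
import numpy as np

index = faiss.read_index('output/index')               # path is a placeholder
queries = np.random.rand(4, 768).astype(np.float32)    # 4 dummy question vectors, dim 768
scores, ids = index.search(queries, 10)                 # top-10 hits per query
print(ids.shape, scores.shape)                          # (4, 10) (4, 10)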
Example #5
        encoder.eval()

        # load weights from the model file
        model_to_load = get_model_obj(encoder)
        prefix_len = len('question_model.')
        question_encoder_state = {
            key[prefix_len:]: value
            for (key, value) in saved_state.model_dict.items()
            if key.startswith('question_model.')
        }
        model_to_load.load_state_dict(question_encoder_state)
        vector_size = model_to_load.get_out_size()

        index_buffer_sz = args.index_buffer
        if args.hnsw_index:
            index = DenseHNSWFlatIndexer(vector_size)
            index_buffer_sz = -1  # encode all at once
        else:
            index = DenseFlatIndexer(vector_size)

        retriever = DenseRetriever(encoder, args.batch_size, tensorizer, index)
        retriever.index.deserialize_from(args.dense_index_path)

        questions_tensor = retriever.generate_question_vectors(questions)
        top_ids_and_scores = retriever.get_top_docs(questions_tensor.numpy(), args.n_docs)


    all_passages = load_passages(args.db_path)

    retrieval_file = "tmp_{}.json".format(str(np.random.randint(0, 100000)).zfill(6))
    questions_doc_hits = validate(all_passages, question_answers, top_ids_and_scores,
                                  1, args.match)

    save_results(all_passages,
Example #6
def setup_dpr(model_file,
              ctx_file,
              encoded_ctx_file,
              hnsw_index=False,
              save_or_load_index=False):
    global retriever
    global all_passages
    global answer_cache
    global answer_cache_path
    parameter_setting = model_file + ctx_file + encoded_ctx_file
    answer_cache_path = hashlib.sha1(
        parameter_setting.encode("utf-8")).hexdigest()
    if os.path.exists(answer_cache_path):
        answer_cache = pickle.load(open(answer_cache_path, 'rb'))
    else:
        answer_cache = {}
    parser = argparse.ArgumentParser()
    add_encoder_params(parser)
    add_tokenizer_params(parser)
    add_cuda_params(parser)

    args = parser.parse_args()
    args.model_file = model_file
    args.ctx_file = ctx_file
    args.encoded_ctx_file = encoded_ctx_file
    args.hnsw_index = hnsw_index
    args.save_or_load_index = save_or_load_index
    args.batch_size = 1  # TODO

    setup_args_gpu(args)
    print_args(args)

    saved_state = load_states_from_checkpoint(args.model_file)
    set_encoder_params_from_state(saved_state.encoder_params, args)

    tensorizer, encoder, _ = init_biencoder_components(args.encoder_model_type,
                                                       args,
                                                       inference_only=True)

    encoder = encoder.question_model

    encoder, _ = setup_for_distributed_mode(encoder, None, args.device,
                                            args.n_gpu, args.local_rank,
                                            args.fp16)
    encoder.eval()

    # load weights from the model file
    model_to_load = get_model_obj(encoder)
    logger.info("Loading saved model state ...")

    prefix_len = len("question_model.")
    question_encoder_state = {
        key[prefix_len:]: value
        for (key, value) in saved_state.model_dict.items()
        if key.startswith("question_model.")
    }
    model_to_load.load_state_dict(question_encoder_state)
    vector_size = model_to_load.get_out_size()
    logger.info("Encoder vector_size=%d", vector_size)

    if args.hnsw_index:
        index = DenseHNSWFlatIndexer(vector_size, 50000)
    else:
        index = DenseFlatIndexer(vector_size, 50000,
                                 "IVF65536,PQ64")  #IVF65536

    retriever = DenseRetriever(encoder, args.batch_size, tensorizer, index)

    # index all passages
    ctx_files_pattern = args.encoded_ctx_file
    input_paths = glob.glob(ctx_files_pattern)

    index_path = "_".join(input_paths[0].split("_")[:-1])
    if args.save_or_load_index and (os.path.exists(index_path) or
                                    os.path.exists(index_path + ".index.dpr")):
        retriever.index.deserialize_from(index_path)
    else:
        logger.info("Reading all passages data from files: %s", input_paths)
        retriever.index.index_data(input_paths)

        if args.save_or_load_index:
            retriever.index.serialize(index_path)
        # get questions & answers

    all_passages = load_passages(args.ctx_file)
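A hedged call sketch for setup_dpr(); every path is a placeholder, and the function populates the globals retriever, all_passages and answer_cache as a side effect.

# Hypothetical invocation; all paths below are placeholders.
setup_dpr(
    model_file='checkpoints/dpr_biencoder.cp',
    ctx_file='data/psgs_w100.tsv',
    encoded_ctx_file='embeddings/wiki_passages_*',
    hnsw_index=False,
    save_or_load_index=True,
)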
Example #7
def main(args):
    questions = []
    question_answers = []
    for i, ds_item in enumerate(parse_qa_csv_file(args.qa_file)):
        #if i == 10:
        #    break
        question, answers = ds_item
        questions.append(question)
        question_answers.append(answers)
    if args.encoder_model_type not in ('hf_attention', 'colbert'):
        saved_state = load_states_from_checkpoint(args.model_file)
        set_encoder_params_from_state(saved_state.encoder_params, args)

    tensorizer, encoder, _ = init_biencoder_components(args.encoder_model_type, args, inference_only=True)

    if args.encoder_model_type not in ('hf_attention', 'colbert'):
        encoder = encoder.question_model

    encoder, _ = setup_for_distributed_mode(encoder, None, args.device, args.n_gpu,
                                            args.local_rank,
                                            args.fp16)
    encoder.eval()

    if args.encoder_model_type not in ('hf_attention', 'colbert'):
        # load weights from the model file
        model_to_load = get_model_obj(encoder)
        logger.info('Loading saved model state ...')

        prefix_len = len('question_model.')
        question_encoder_state = {key[prefix_len:]: value for (key, value) in saved_state.model_dict.items() if
                                  key.startswith('question_model.')}
        model_to_load.load_state_dict(question_encoder_state)
        vector_size = model_to_load.get_out_size()
        logger.info('Encoder vector_size=%d', vector_size)
    else:
        vector_size = 16

    index_buffer_sz = args.index_buffer

    if args.index_type == 'hnsw':
        index = DenseHNSWFlatIndexer(vector_size, index_buffer_sz)
        index_buffer_sz = -1  # encode all at once
    elif args.index_type == 'custom':
        index = CustomIndexer(vector_size, index_buffer_sz)
    else:
        index = DenseFlatIndexer(vector_size)

    retriever = DenseRetriever(encoder, args.batch_size, tensorizer, index,
                               args.encoder_model_type == 'colbert')

    # index all passages
    ctx_files_pattern = args.encoded_ctx_file
    input_paths = glob.glob(ctx_files_pattern)
    #logger.info('Reading all passages data from files: %s', input_paths)
    #memmap = retriever.index_encoded_data(input_paths, buffer_size=index_buffer_sz, memmap=args.encoder_model_type=='colbert')
    print(input_paths)
    index_path = "_".join(input_paths[0].split("_")[:-1])
    memmap_path = 'memmap.npy'
    print(args.save_or_load_index, os.path.exists(index_path), index_path)
    if args.save_or_load_index and os.path.exists(index_path + '.index.dpr'):
        retriever.index.deserialize_from(index_path)
        if args.encoder_model_type == 'colbert':
            # reuse the memmap written during indexing; 'w+' here would overwrite it with zeros
            memmap = np.memmap(memmap_path, dtype=np.float32, mode='r', shape=(21015324, 250, 16))
        else:
            memmap = None
    else:
        logger.info('Reading all passages data from files: %s', input_paths)
        if args.encoder_model_type == 'colbert':
            memmap = np.memmap(memmap_path, dtype=np.float32, mode='w+', shape=(21015324, 250, 16))
        else:
            memmap = None
        retriever.index_encoded_data(input_paths, buffer_size=index_buffer_sz, memmap=memmap)
        if args.save_or_load_index:
            retriever.index.serialize(index_path)
    # get questions & answers



    questions_tensor = retriever.generate_question_vectors(questions)

    # get top k results
    top_ids_and_scores = retriever.get_top_docs(questions_tensor.numpy(), args.n_docs,
                                                is_colbert=args.encoder_model_type == 'colbert')

    with open('approx_scores.pkl', 'wb') as f:
        pickle.dump(top_ids_and_scores, f)

    retriever.index.index.reset()
    if args.encoder_model_type == 'colbert':
        logger.info('Running ColBERT rescoring ...')
        top_ids_and_scores_colbert = retriever.colbert_search(questions_tensor.numpy(), memmap,
                                                              top_ids_and_scores, args.n_docs)

    all_passages = load_passages(args.ctx_file)

    if len(all_passages) == 0:
        raise RuntimeError('No passages data found. Please specify ctx_file param properly.')


    if args.encoder_model_type == 'colbert':
        with open('colbert_scores.pkl', 'wb') as f:
            pickle.dump(top_ids_and_scores_colbert, f)

    questions_doc_hits = validate(all_passages, question_answers, top_ids_and_scores, args.validation_workers,
                                  args.match)
    if args.encoder_model_type == 'colbert':
        questions_doc_hits_colbert = validate(all_passages, question_answers, top_ids_and_scores_colbert,
                                              args.validation_workers, args.match)

    if args.out_file:
        save_results(all_passages, questions, question_answers, top_ids_and_scores,
                     questions_doc_hits, args.out_file)
        if args.encoder_model_type == 'colbert':
            save_results(all_passages, questions, question_answers, top_ids_and_scores_colbert,
                         questions_doc_hits_colbert, args.out_file + 'colbert')