Example #1
    def __init__(self, args, need_vocab=True):
        self.tfidf_path = args.tfidf_path
        self.ranker = retriever.get_class('tfidf')(tfidf_path=self.tfidf_path)
        self.first_para_only = False
        self.db = DocDB(args.wiki_db_path)
        self.L = 300

        if need_vocab:
            tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
            btokenizer = BasicTokenizer()
            self.tokenize = lambda c, t_c: tokenizer.tokenize(c)
        self.btokenize = btokenizer.tokenize

        self.keyword2title = defaultdict(list)
        self.cache = {}
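
    # A minimal usage sketch (hypothetical method, not part of the original class):
    # rank documents with the DrQA-style TF-IDF ranker initialized above, read their
    # text from the document DB, and cache the result per query.
    def retrieve_docs(self, query, k=5):
        if query in self.cache:
            return self.cache[query]
        doc_titles, doc_scores = self.ranker.closest_docs(query, k)
        doc_texts = [self.db.get_doc_text(title) for title in doc_titles]
        self.cache[query] = (doc_titles, doc_scores, doc_texts)
        return self.cache[query]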
Example #2
def get_count_matrix(args, db, db_opts):
    """Form a sparse word to document count matrix (inverted index).

    M[i, j] = # times word i appears in document j.
    """
    # Map doc_ids to indexes
    global DOC2IDX
    db_class = retriever.get_class(db)
    with db_class(**db_opts) as doc_db:
        doc_ids = doc_db.get_doc_ids()
    DOC2IDX = {doc_id: i for i, doc_id in enumerate(doc_ids)}

    # Setup worker pool
    tok_class = tokenizers.get_class(args.tokenizer)
    workers = ProcessPool(
        args.num_workers,
        initializer=init,
        initargs=(tok_class, db_class, db_opts)
    )

    # Compute the count matrix in steps (to keep in memory)
    logger.info('Mapping...')
    row, col, data = [], [], []
    step = max(int(len(doc_ids) / 10), 1)
    batches = [doc_ids[i:i + step] for i in range(0, len(doc_ids), step)]
    _count = partial(count, args.ngram, args.hash_size)
    for i, batch in enumerate(batches):
        logger.info('-' * 25 + 'Batch %d/%d' % (i + 1, len(batches)) + '-' * 25)
        for b_row, b_col, b_data in workers.imap_unordered(_count, batch):
            row.extend(b_row)
            col.extend(b_col)
            data.extend(b_data)
    workers.close()
    workers.join()

    logger.info('Creating sparse matrix...')
    count_matrix = sp.csr_matrix(
        (data, (row, col)), shape=(args.hash_size, len(doc_ids))
    )
    count_matrix.sum_duplicates()
    return count_matrix, (DOC2IDX, doc_ids)
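
# For context, a sketch of how the returned count matrix is typically converted to
# TF-IDF weights in a DrQA-style build script: tfidf = log(1 + tf) * idf with
# idf = log((N - Nt + 0.5) / (Nt + 0.5)). Assumes numpy is imported as np and
# scipy.sparse as sp, as in the surrounding module; the function name is illustrative.
def get_tfidf_matrix(cnts):
    """Convert a (hash_size x num_docs) count matrix into a TF-IDF weighted matrix."""
    binary = (cnts > 0).astype(int)
    Ns = np.array(binary.sum(1)).squeeze()   # number of docs containing each hashed term
    idfs = np.log((cnts.shape[1] - Ns + 0.5) / (Ns + 0.5))
    idfs[idfs < 0] = 0                       # clamp negative idf values
    idf_diag = sp.diags(idfs, 0)
    tfs = cnts.log1p()                       # sublinear term-frequency scaling
    return idf_diag.dot(tfs)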
Example #3
    questions = []
    question_answers = []
    for ds_item in parse_qa_csv_file(args.qa_file):
        question, answers = ds_item
        questions.append(question)
        question_answers.append(answers)

    if args.retrieval_type == "tfidf":
        import drqa_retriever as retriever
        ranker = retriever.get_class('tfidf')(tfidf_path=args.tfidf_path)
        top_ids_and_scores = []
        for question in questions:
            psg_ids, scores = ranker.closest_docs(question, args.n_docs)
            top_ids_and_scores.append((psg_ids, scores))
    else:
        from dpr.models import init_biencoder_components
        from dpr.options import set_encoder_params_from_state
        from dpr.utils.data_utils import Tensorizer
        from dpr.utils.model_utils import setup_for_distributed_mode, get_model_obj, load_states_from_checkpoint
        from dpr.indexer.faiss_indexers import DenseIndexer, DenseHNSWFlatIndexer, DenseFlatIndexer
        from dense_retriever import DenseRetriever

        saved_state = load_states_from_checkpoint(args.dpr_model_file)
        set_encoder_params_from_state(saved_state.encoder_params, args)
        tensorizer, encoder, _ = init_biencoder_components(args.encoder_model_type, args, inference_only=True)
        encoder = encoder.question_model
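
        # The rest of the dense branch is a rough sketch of how DPR's dense_retriever.py
        # typically proceeds; the exact constructor signatures and args names used here
        # (batch_size, index_buffer, encoded_ctx_files, hnsw_index, n_docs) are assumptions
        # and vary across DPR versions. Loading the question-encoder weights from
        # saved_state is omitted for brevity.
        encoder, _ = setup_for_distributed_mode(encoder, None, args.device, args.n_gpu,
                                                args.local_rank, args.fp16)
        encoder.eval()
        vector_size = get_model_obj(encoder).get_out_size()

        # Build a FAISS index over pre-computed passage embeddings, encode the questions,
        # and retrieve the top passages so that top_ids_and_scores mirrors the tfidf branch.
        index = DenseHNSWFlatIndexer(vector_size) if args.hnsw_index else DenseFlatIndexer(vector_size)
        retriever = DenseRetriever(encoder, args.batch_size, tensorizer, index)
        retriever.index_encoded_data(args.encoded_ctx_files, buffer_size=args.index_buffer)
        questions_tensor = retriever.generate_question_vectors(questions)
        top_ids_and_scores = retriever.get_top_docs(questions_tensor.numpy(), args.n_docs)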