Example #1
    def __init__(self, tfidf_path=None, strict=True):
        """
        Args:
            tfidf_path: path to the saved model file
            strict: whether to fail on empty queries, or to continue and
                return an empty result
        """
        # Load the sparse TF-IDF matrix and its metadata from disk
        tfidf_path = tfidf_path or DEFAULTS['tfidf_path']
        logger.info('Loading retriever %s' % tfidf_path)
        matrix, metadata = utils.load_sparse_csr(tfidf_path)
        self.doc_mat = matrix
        self.ngrams = metadata['ngram']
        self.hash_size = metadata['hash_size']
        self.tokenizer = tokenizers.get_class(metadata['tokenizer'])()
        self.doc_freqs = metadata['doc_freqs'].squeeze()
        self.doc_dict = metadata['doc_dict']
        self.num_docs = len(self.doc_dict[0])
        self.strict = strict
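
This constructor appears to belong to the TF-IDF document ranker that Example #4 instantiates as TfidfDocRanker. A minimal usage sketch under that assumption; the model path and query are placeholders, and batch_closest_docs is the batched entry point shown in Example #4:

# Hypothetical usage; the path and query are placeholders.
ranker = TfidfDocRanker(tfidf_path='/path/to/saved/model.npz')

# Assumed to return one (doc_ids, doc_scores) pair per query.
results = ranker.batch_closest_docs(['Who wrote Hamlet?'], k=5, num_workers=2)
doc_ids, doc_scores = results[0]
print(doc_ids, doc_scores)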
Example #2
def process_dataset(data, tokenizer, workers=None):
    """Process the dataset (tokenize, parse, etc.) in parallel with a pool
    of worker processes, yielding one example per question."""
    tokenizer_class = tokenizers.get_class(tokenizer)
    make_pool = partial(Pool, workers, initializer=init)

    # Tokenize questions (lemma annotations only).
    pool = make_pool(initargs=(tokenizer_class, {'annotators': {'lemma'}}))
    q_tokens = pool.map(tokenize, data['questions'])
    pool.close()
    pool.join()

    # Tokenize contexts with richer annotations (lemma, POS, NER).
    pool = make_pool(initargs=(tokenizer_class, {
        'annotators': {'lemma', 'pos', 'ner'}
    }))
    c_tokens = pool.map(tokenize, data['contexts'])
    pool.close()
    pool.join()

    # Align each question with its context via qid2cid, and map answer
    # character spans onto token spans.
    for idx in range(len(data['qids'])):
        question = q_tokens[idx]['words']
        qlemma = q_tokens[idx]['lemma']
        document = c_tokens[data['qid2cid'][idx]]['words']
        offsets = c_tokens[data['qid2cid'][idx]]['offsets']
        lemma = c_tokens[data['qid2cid'][idx]]['lemma']
        pos = c_tokens[data['qid2cid'][idx]]['pos']
        ner = c_tokens[data['qid2cid'][idx]]['ner']
        ans_tokens = []
        if len(data['answers']) > 0:
            for ans in data['answers'][idx]:
                found = find_answer(offsets, ans['answer_start'],
                                    ans['answer_start'] + len(ans['text']))
                if found:
                    ans_tokens.append(found)
        yield {
            'id': data['qids'][idx],
            'question': question,
            'document': document,
            'offsets': offsets,
            'answers': ans_tokens,
            'qlemma': qlemma,
            'lemma': lemma,
            'pos': pos,
            'ner': ner,
        }
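
Example #2 calls a find_answer helper that is not shown above. A minimal sketch of what it might look like, assuming offsets holds one (start_char, end_char) pair per token, as produced by the tokenizer:

def find_answer(offsets, begin_offset, end_offset):
    """Map an answer's character span onto token indices.

    A sketch under the assumption that `offsets` lists one
    (start_char, end_char) pair per token; returns None when the
    span does not align with token boundaries.
    """
    start = [i for i, tok in enumerate(offsets) if tok[0] == begin_offset]
    end = [i for i, tok in enumerate(offsets) if tok[1] == end_offset]
    if len(start) == 1 and len(end) == 1:
        return (start[0], end[0])
    return None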
Example #3
def main(args):
    # Read query data
    start = time.time()
    logger.info('Reading query data {}'.format(args.query_data))
    questions = []
    answers = []
    with open(args.query_data) as f:
        for line in f:
            qa_pair = json.loads(line)
            questions.append(qa_pair['question'])
            answers.append(qa_pair['answer'])

    # Load candidates
    candidates = None
    if args.candidate_file:
        logger.info('Loading candidates from %s' % args.candidate_file)
        candidates = set()
        with open(args.candidate_file) as f:
            for line in f:
                line = utils.normalize(line.strip()).lower()
                candidates.add(line)
        logger.info('Loaded %d candidates.' % len(candidates))

    # Get the closest docs for each question.
    logger.info('Initializing pipeline...')
    pipeline = QAPipeline(retriever_path=args.retriever_path,
                          db_path=args.db_path,
                          ranker_path=args.ranker_path,
                          reader_path=args.reader_path,
                          module_batch_size=args.module_batch_size,
                          module_max_loaders=args.module_max_loaders,
                          module_cuda=args.cuda,
                          fixed_candidates=candidates)

    # Batchify questions and feed them in for prediction
    batches = [questions[i: i + args.predict_batch_size]
               for i in range(0, len(questions), args.predict_batch_size)]
    batches_targets = [answers[i: i + args.predict_batch_size]
                       for i in range(0, len(answers), args.predict_batch_size)]

    # Predict and record results
    logger.info('Predicting...' if not args.train else 'Training...')
    if args.train:
        best_loss = float('inf')
        for e_idx in range(args.train_epoch):
            logger.info('Epoch {}'.format(e_idx + 1))
            # Reset the running-average loss at the start of each epoch so
            # the best-loss comparison reflects this epoch alone.
            train_loss = utils.AverageMeter()
            for i, (batch, target) in enumerate(zip(batches, batches_targets)):
                logger.info(
                    '-' * 25 + ' Batch %d/%d ' % (i + 1, len(batches)) + '-' * 25
                )
                loss = pipeline.update(batch, target,
                                       n_docs=args.n_docs, n_pars=args.n_pars)
                train_loss.update(loss, 1)

            # Checkpoint the ranker whenever the epoch loss improves.
            if train_loss.avg < best_loss:
                logger.info('Best loss = %.3f' % train_loss.avg)
                pipeline.ranker.save(args.ranker_file)
                best_loss = train_loss.avg

        logger.info('Training done')
        return
    else:
        closest_pars = []
        with open(args.pred_file, 'w') as pred_f:
            for i, (batch, target) in enumerate(zip(batches, batches_targets)):
                logger.info(
                    '-' * 25 + ' Batch %d/%d ' % (i + 1, len(batches)) + '-' * 25
                )
                with torch.no_grad():
                    closest_par, predictions = pipeline.predict(batch,
                                                                n_docs=args.n_docs,
                                                                n_pars=args.n_pars)
                    closest_pars += closest_par

                for p in predictions:
                    pred_f.write(json.dumps(p) + '\n')

            answers_pars = zip(answers, closest_pars)

    # Define worker processes
    tok_class = tokenizers.get_class(args.tokenizer)
    tok_opts = {}
    db_class = DocDB
    db_opts = {'db_path': args.db_path}
    processes = ProcessPool(
        processes=args.data_workers,
        initializer=init,
        initargs=(tok_class, tok_opts, db_class, db_opts)
    )

    # Compute the scores for each pair, and print the statistics
    logger.info('Retrieving and computing scores...')
    get_score_partial = partial(get_score, match=args.match, use_text=True)
    scores = processes.map(get_score_partial, answers_pars)

    filename = os.path.basename(args.query_data)
    stats = (
        "\n" + "-" * 50 + "\n" +
        "{filename}\n" +
        "Examples:\t\t\t{total}\n" +
        "Matches in top {k}:\t\t{m}\n" +
        "Match % in top {k}:\t\t{p:2.2f}\n" +
        "Total time:\t\t\t{t:2.4f} (s)\n"
    ).format(
        filename=filename,
        total=len(scores),
        k=args.n_docs,
        m=sum(scores),
        p=(sum(scores) / len(scores) * 100),
        t=time.time() - start,
    )

    print(stats)
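
The training branch above relies on utils.AverageMeter without showing it. A minimal sketch of such a running-average tracker, assuming the conventional update(val, n) / avg interface used in the loop:

class AverageMeter(object):
    """Tracks a running average; a sketch of the interface used above."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.sum = 0
        self.count = 0
        self.avg = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count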
Example #4
        questions.append(question)
        answers.append(answer)

    # Get the closest docs for each question.
    logger.info('Initializing ranker...')
    ranker = TfidfDocRanker(tfidf_path=args.model)

    logger.info('Ranking...')
    logger.info('Processing query: %s' % questions[0])
    closest_docs = ranker.batch_closest_docs(
        questions, k=args.n_docs, num_workers=args.num_workers
    )
    answers_docs = zip(answers, closest_docs)

    # Define worker processes
    tok_class = tokenizers.get_class(args.tokenizer)
    tok_opts = {}
    db_class = DocDB
    db_opts = {'db_path': args.doc_db}
    processes = ProcessPool(
        processes=args.num_workers,
        initializer=init,
        initargs=(tok_class, tok_opts, db_class, db_opts)
    )

    # Compute the scores for each pair, and print the statistics
    logger.info('Retrieving and computing scores...')
    get_score_partial = partial(get_score, match=args.match)
    scores = processes.map(get_score_partial, answers_docs)

    filename = os.path.basename(args.dataset)
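
Examples #3 and #4 hand an init function and a get_score function to the process pool without showing either. One plausible shape, assuming each worker keeps a process-global tokenizer and DB handle, that each work item pairs a gold answer with its ranked results, and that a has_answer helper (hypothetical here) checks a single document:

PROCESS_TOK = None
PROCESS_DB = None

def init(tokenizer_class, tokenizer_opts, db_class, db_opts):
    # Runs once in each worker: build process-local tokenizer/DB handles
    # so they are not pickled across process boundaries.
    global PROCESS_TOK, PROCESS_DB
    PROCESS_TOK = tokenizer_class(**tokenizer_opts)
    PROCESS_DB = db_class(**db_opts)

def get_score(answer_doc, match='string', use_text=False):
    # A sketch: assumes answer_doc pairs a gold answer with the ranked
    # (doc_ids, doc_scores) for its question, and that has_answer() is
    # a helper that checks one document. The use_text flag mirrors
    # Example #3's paragraph-text variant and is a placeholder here.
    answer, (doc_ids, doc_scores) = answer_doc
    for doc_id in doc_ids:
        if has_answer(answer, doc_id, match):
            return 1
    return 0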