コード例 #1
0
ファイル: eval.py プロジェクト: rahular/ellipsis-baselines
def evaluate(dataset_file, prediction_file, regex=False):
    """Print the exact-match accuracy of predicted spans against gold answers.

    Args:
        dataset_file: JSON-lines file; each line has an ``"answer"`` list of
            acceptable answer strings.
        prediction_file: JSON-lines file aligned line-by-line with
            ``dataset_file``; each line is a list whose first element carries
            the top prediction under ``"span"``.
        regex: If True, gold answers are matched with ``regex_match_score``;
            otherwise ``exact_match_score`` is used.

    Prints ``{"exact_match": <percent>}``; returns None.
    """
    print("-" * 50)
    print("Dataset: %s" % dataset_file)
    print("Predictions: %s" % prediction_file)

    # Read the file we were asked to evaluate (the old code read the global
    # ``args.dataset`` and never closed the handle).
    answers = []
    with open(dataset_file) as f:
        for line in f:
            data = json.loads(line)
            answers.append([normalize(a) for a in data["answer"]])

    predictions = []
    with open(prediction_file) as f:
        for line in f:
            data = json.loads(line)
            predictions.append(normalize(data[0]["span"]))

    total = len(predictions)
    if total == 0:
        # Nothing to score: report 0% instead of dividing by zero.
        print({"exact_match": 0.0})
        return

    match_fn = regex_match_score if regex else exact_match_score
    exact_match = 0
    for i, prediction in enumerate(predictions):
        # Predictions and answers are aligned by line position.
        exact_match += metric_max_over_ground_truths(match_fn, prediction,
                                                     answers[i])
    exact_match = 100.0 * exact_match / total
    print({"exact_match": exact_match})
コード例 #2
0
def has_answer(answer, doc_id, match):
    """Return True if document ``doc_id`` contains an answer.

    ``match == 'string'``: every entry of ``answer`` is normalized, tokenized
    and looked for as a contiguous run of lower-cased tokens in the
    normalized document.
    ``match == 'regex'``: only ``answer[0]`` is used; it is normalized and
    searched for with ``regex_match`` over the whole normalized text.
    Any other ``match`` value yields False. Reads the module globals
    ``PROCESS_DB`` and ``PROCESS_TOK``.
    """
    doc_text = utils.normalize(PROCESS_DB.get_doc_text(doc_id))

    if match == 'regex':
        pattern = utils.normalize(answer[0])
        return bool(regex_match(doc_text, pattern))

    if match != 'string':
        return False

    doc_tokens = PROCESS_TOK.tokenize(doc_text).words(uncased=True)
    for candidate in answer:
        needle = PROCESS_TOK.tokenize(utils.normalize(candidate)).words(uncased=True)
        width = len(needle)
        if any(doc_tokens[pos:pos + width] == needle
               for pos in range(len(doc_tokens) - width + 1)):
            return True
    return False
コード例 #3
0
    def exact_match(self, line):
        """Find candidate pages whose cleaned title is covered by the claim.

        The claim's noun phrases are used to retrieve candidate pages
        (de-duplicated). A candidate is kept when every stemmed word of its
        title, after removing ``-LRB-...-RRB-`` parts and mapping ``_``,
        ``-COLON-``, dashes and periods, also occurs among the stemmed words
        of the normalized claim (periods dropped, hyphens turned into spaces).

        Returns:
            ``(noun_phrases, wiki_results, predicted_pages)``. ``predicted_pages``
            extends ``self.np_conc(noun_phrases)`` with the kept titles and is
            de-duplicated (order not preserved).
        """
        noun_phrases = self.get_noun_phrases(line)
        wiki_results = list(set(self.get_doc_for_claim(noun_phrases)))

        claim = normalize(line['claim']).replace(".", "").replace("-", " ")
        stemmer = self.proter_stemm
        claim_stems = set(stemmer.stem(token.lower()) for token in self.tokenizer(claim))
        predicted_pages = self.np_conc(noun_phrases)

        for title in wiki_results:
            title = normalize(title)
            cleaned = re.sub("-LRB-.*?-RRB-", "", title)
            cleaned = cleaned.replace("_", " ").replace("-COLON-", ":")
            cleaned = cleaned.replace("-", " ").replace("–", " ").replace(".", "")
            title_stems = [stemmer.stem(token.lower())
                           for token in self.tokenizer(cleaned) if len(token) > 0]
            # An empty title_stems list passes vacuously.
            if all(stem in claim_stems for stem in title_stems):
                # Re-escape ':' as -COLON-, the form un-escaped above.
                predicted_pages.append(title.replace(":", "-COLON-"))

        return noun_phrases, wiki_results, list(set(predicted_pages))
コード例 #4
0
ファイル: eval.py プロジェクト: athiwatp/DrQA
def evaluate(dataset_file, prediction_file, regex=False):
    """Print the exact-match accuracy of predicted spans against gold answers.

    Args:
        dataset_file: JSON-lines file; each line has an ``'answer'`` list of
            acceptable answer strings.
        prediction_file: JSON-lines file aligned line-by-line with
            ``dataset_file``; each line is a list whose first element carries
            the top prediction under ``'span'``.
        regex: If True, gold answers are matched with ``regex_match_score``;
            otherwise ``exact_match_score`` is used.

    Prints ``{'exact_match': <percent>}``; returns None.
    """
    print('-' * 50)
    print('Dataset: %s' % dataset_file)
    print('Predictions: %s' % prediction_file)

    # Read the file we were asked to evaluate (the old code read the global
    # ``args.dataset`` and never closed the handle).
    answers = []
    with open(dataset_file) as f:
        for line in f:
            data = json.loads(line)
            answers.append([normalize(a) for a in data['answer']])

    predictions = []
    with open(prediction_file) as f:
        for line in f:
            data = json.loads(line)
            predictions.append(normalize(data[0]['span']))

    total = len(predictions)
    if total == 0:
        # Nothing to score: report 0% instead of dividing by zero.
        print({'exact_match': 0.0})
        return

    match_fn = regex_match_score if regex else exact_match_score
    exact_match = 0
    for i, prediction in enumerate(predictions):
        # Predictions and answers are aligned by line position.
        exact_match += metric_max_over_ground_truths(
            match_fn, prediction, answers[i]
        )
    exact_match = 100.0 * exact_match / total
    print({'exact_match': exact_match})
コード例 #5
0
ファイル: eval.py プロジェクト: athiwatp/DrQA
def has_answer(answer, doc_id, match):
    """Return True if document ``doc_id`` contains an answer.

    With ``match == 'string'`` each answer in ``answer`` is normalized,
    tokenized and searched for as a contiguous sequence of lower-cased
    tokens. With ``match == 'regex'`` the normalized ``answer[0]`` is passed
    to ``regex_match`` together with the whole normalized document text.
    Other ``match`` values give False. Uses the ``PROCESS_DB`` and
    ``PROCESS_TOK`` module globals (read-only).
    """
    normalized = utils.normalize(PROCESS_DB.get_doc_text(doc_id))

    if match == 'string':
        words = PROCESS_TOK.tokenize(normalized).words(uncased=True)
        for option in answer:
            option_words = PROCESS_TOK.tokenize(
                utils.normalize(option)).words(uncased=True)
            span = len(option_words)
            start = 0
            while start <= len(words) - span:
                if words[start:start + span] == option_words:
                    return True
                start += 1
        return False

    if match == 'regex':
        return bool(regex_match(normalized, utils.normalize(answer[0])))

    return False
コード例 #6
0
def gold_evidence_to_list(gold_evidences):
    """Turn gold evidence sets into lists of ``'<a>§§§<b>'`` keys.

    For each entry of each evidence set, its last two fields (presumably the
    page and sentence id — confirm against the dataset) are stringified,
    normalized and joined with ``'§§§'``. The nesting of sets is preserved.
    """
    return [
        [normalize(str(entry[-2])) + '§§§' + normalize(str(entry[-1]))
         for entry in evidence_set]
        for evidence_set in gold_evidences
    ]
コード例 #7
0
def get_contents(filename):
    """Parse a bz2-compressed JSON-lines file into document rows.

    Each line is one JSON document. If ``PREPROCESS_FN`` is set it is applied
    first, and falsy results are skipped. ``doc['text']`` and
    ``doc['text_with_links']`` must have equal length (asserted).

    Returns:
        A list of ``(normalized id, url, title, pickled text,
        pickled text_with_links, pickled per-sentence entities,
        number of sentences)`` tuples. Entities are
        ``(text, start_char, end_char, label_)`` tuples from ``nlp(sentence).ents``.
    """
    rows = []
    with bz2.open(filename, 'rb') as handle:
        for raw in handle:
            doc = json.loads(raw)
            if PREPROCESS_FN:
                doc = PREPROCESS_FN(doc)
            if not doc:
                continue

            sentences = doc['text']
            # Sentences and their link-annotated versions must line up 1:1.
            assert len(sentences) == len(doc['text_with_links'])
            pickled_text = pickle.dumps(sentences)
            pickled_links = pickle.dumps(doc['text_with_links'])

            entities = [
                [(ent.text, ent.start_char, ent.end_char, ent.label_)
                 for ent in nlp(sentence).ents]
                for sentence in sentences
            ]
            pickled_entities = pickle.dumps(entities)

            rows.append((utils.normalize(doc['id']), doc['url'], doc['title'],
                         pickled_text, pickled_links, pickled_entities,
                         len(sentences)))
    return rows
コード例 #8
0
ファイル: build_db.py プロジェクト: mazzzystar/DrQAChinese
def get_contents(filename):
    """Parse a JSON-lines file into ``(id, title, text)`` rows, best-effort.

    The file is read as UTF-8 with undecodable bytes ignored. Each line is
    parsed as JSON and optionally passed through ``PREPROCESS_FN``. Falsy
    results are skipped. The rest become
    ``(utils.normalize(doc['id']), doc['title'], doc['text'])``. Lines that
    raise while being processed are counted and skipped. The total line
    count and error count are printed at the end.
    """
    documents = []
    count = 0
    err_count = 0
    with open(filename, encoding="utf8", errors='ignore') as f:
        for line in f:
            count += 1
            try:
                doc = json.loads(line)
                if PREPROCESS_FN:
                    doc = PREPROCESS_FN(doc)
                if not doc:
                    continue
                documents.append((utils.normalize(doc['id']), doc['title'], doc['text']))
            except Exception:
                # Deliberate best-effort skip of bad lines, but unlike the
                # former bare ``except:`` this lets KeyboardInterrupt and
                # SystemExit propagate.
                err_count += 1
    print("count={} err={}".format(count, err_count))
    return documents
コード例 #9
0
def rerankDocs(questions, answers, closest_docs, db, max_docs=5):
    """Filter retrieved documents with a binary relevance classifier.

    For each question, the texts of its retrieved documents are fetched from
    ``db``, normalized, paired with the question as ``InputExample``s and
    classified by ``getPredictions``. Documents predicted as ``1`` are kept
    in retrieval order, at most ``max_docs`` per question.

    Args:
        questions: Question strings, aligned with ``closest_docs``.
        answers: Gold answers, passed through unchanged.
        closest_docs: Iterable of ``(doc_ids, scores)`` pairs; scores unused.
        db: Object exposing ``get_doc_text(doc_id)``.
        max_docs: Maximum documents kept per question (default 5).

    Returns:
        ``zip(answers, kept_doc_ids_per_question, questions)``.
    """
    documents = [
        [(utils.normalize(db.get_doc_text(doc_id)), doc_id) for doc_id in doc_ids]
        for doc_ids, _ in closest_docs
    ]

    return_documents = []
    for question, docs in zip(questions, documents):
        samples = [InputExample(guid="%s" % i, text_a=text, text_b=question)
                   for i, (text, _) in enumerate(docs)]
        preds = getPredictions(samples)
        kept = []
        for i, pred in enumerate(preds):
            # The old code never incremented its counter, so the cap never
            # applied; track the cap by the number of documents kept.
            if len(kept) >= max_docs:
                break
            if pred == 1:
                kept.append(docs[i][1])
        return_documents.append(kept)
    return zip(answers, return_documents, questions)
コード例 #10
0
def calculate(qa_pair, n_results):
    """Score retrieval quality for a single question/answer pair.

    ``search`` is run for the question. A returned document counts as a hit
    when the normalized answer occurs, case-insensitively, in its ``"text"``.

    Returns:
        A one-element list ``[(question, answer, answer_found, ap)]``, where
        ``answer_found`` is 1 if any document is a hit and ``ap`` is the
        average precision over hit ranks (0 when there are no hits).
    """
    question = qa_pair['question']
    answer = utils.normalize(qa_pair['answer'])

    needle = answer.lower()
    hits = [int(needle in doc["text"].lower())
            for doc in search(query=question, n_results=n_results)]

    # Precision measured at every rank that holds a hit.
    hit_count = 0
    precisions = []
    for rank, hit in enumerate(hits, start=1):
        if hit == 1:
            hit_count += 1
            precisions.append(hit_count / rank)

    ap = (1 / hit_count) * np.sum(precisions) if hit_count else 0
    return [(question, answer, int(any(hits)), ap)]
コード例 #11
0
def get_contents(filename):
    """Build document rows from a DocRED-style JSON file.

    ``filename`` holds a JSON list of examples. For each example, the sentence
    lists of all ``data['context']`` entries are concatenated. The title is
    taken from the LAST entry: its first field minus the final two
    characters. Only the first example seen for a given title produces a row.

    Returns:
        A list of ``(normalized str(doc_id), "", title, pickled sentences,
        pickled [], pickled per-sentence entities, number of sentences)``.
        Entities are ``(text, start_char, end_char, label_)`` from
        ``nlp(sentence).ents``. ``doc_id`` is the example's position in the
        file, or -1 for an example with no context entries.
    """
    # Close the file deterministically (was ``json.load(open(filename))``).
    with open(filename) as f:
        docred_data = json.load(f)

    documents = []
    title_to_id = {}
    # Example counter, advanced for every example, including duplicates.
    next_id = 0
    for data in tqdm(docred_data):
        text = []
        title = ""
        doc_id = -1
        for entry in data['context']:
            title = entry[0][:-2]
            doc_id = title_to_id.get(title, next_id)
            text.extend(entry[1])
        next_id += 1

        if title in title_to_id:
            continue

        entities = [
            [(ent.text, ent.start_char, ent.end_char, ent.label_)
             for ent in nlp(sent).ents]
            for sent in text
        ]
        documents.append((utils.normalize(str(doc_id)), "", title,
                          pickle.dumps(text), pickle.dumps([]),
                          pickle.dumps(entities), len(text)))
        title_to_id[title] = doc_id

    return documents
コード例 #12
0
    def closest_docs(self, question_, k=5):
        """Retrieve up to ``k`` documents for a question via Lucene.

        The normalized question is turned into a query with ``self.parse``.
        Each line returned by ``self._run_lucene`` is parsed as
        ``"<doc_id> <score> ..." + TEXT_FLAG + "<text>"``.

        Returns:
            ``(doc_ids, doc_scores, doc_texts)``: three parallel lists. Scores
            are the raw string tokens from the Lucene output. All lists are
            empty when the question yields no query terms.
        """
        doc_ids, doc_scores, doc_texts = [], [], []
        query = ' '.join(self.parse(utils.normalize(question_)))
        if not query:
            logger.warning('has no query!')
            return doc_ids, doc_scores, doc_texts

        search_results = self._run_lucene(query, k)
        for result in search_results.split('\n'):
            if not result.strip():
                # Blank lines (e.g. a trailing newline) carry no hit.
                continue
            result_elements = result.split(TEXT_FLAG)
            id_and_score = result_elements[0].split()
            # Skip malformed lines: too few header fields, or no text section
            # after TEXT_FLAG (which previously raised IndexError).
            if len(id_and_score) < 2 or len(result_elements) < 2:
                logger.warning('query failed for question: %s', question_)
                continue
            doc_ids.append(id_and_score[0])
            doc_scores.append(id_and_score[1])
            doc_texts.append(result_elements[1].strip())
        return doc_ids, doc_scores, doc_texts
コード例 #13
0
def gen_query(question_):
    """Build a search query string from a question.

    The question is normalized with ``utils.normalize`` and tokenized with
    the ``'simple'`` tokenizer class. Its lower-cased unigrams, passed
    through the ``utils.filter_ngram`` filter, are joined with single spaces.
    """
    text = utils.normalize(question_)
    tokens = tokenizers.get_class('simple')().tokenize(text)
    return ' '.join(tokens.ngrams(n=1, uncased=True, filter_fn=utils.filter_ngram))
コード例 #14
0
def generate_submission(_predictions, _ids, test_set_path, submission_path):
    """
    Generate submission file for shared task: http://fever.ai/task.html

    Every test line is written back as ``{"id", "predicted_label",
    "predicted_evidence"}``, with evidence page names normalized. Claims that
    have no prediction get ``prediction_2_label(2)``.

    :param _predictions: Predicted class indices, aligned with ``_ids``.
    :param _ids: Claim ids the predictions belong to.
    :param test_set_path: JSON-lines test file; each line has ``id`` and
        ``predicted_evidence``.
    :param submission_path: Output JSON-lines path; missing parent
        directories are created.
    :return: None
    """
    from common.dataset.reader import JSONLineReader
    from tqdm import tqdm
    import json

    # Index predictions once instead of scanning them for every test line.
    # setdefault keeps the FIRST prediction for a duplicated id, matching the
    # former linear scan that stopped at the first hit.
    label_by_id = {}
    for _pid, _plabel in zip(_ids, _predictions):
        label_by_id.setdefault(_pid, _plabel)
    default_label = prediction_2_label(2)

    jlr = JSONLineReader()
    json_lines = jlr.read(test_set_path)
    os.makedirs(os.path.dirname(os.path.abspath(submission_path)), exist_ok=True)
    with open(submission_path, 'w') as f:
        for line in tqdm(json_lines):
            for i, evidence in enumerate(line['predicted_evidence']):
                line['predicted_evidence'][i][0] = normalize(evidence[0])
            _id = line['id']
            if _id in label_by_id:
                _pred_label = prediction_2_label(label_by_id[_id])
            else:
                _pred_label = default_label
            obj = {"id": _id, "predicted_label": _pred_label, "predicted_evidence": line['predicted_evidence']}
            f.write(json.dumps(obj))
            f.write('\n')
コード例 #15
0
def get_wiki_entry(name):
    """Resolve ``name`` to a normalized page title present in ``idx``.

    Resolution order:
      1. the cleaned, normalized ``name`` itself;
      2. if ``name`` starts lower-case, the same name with its first
         character upper-cased (resolved recursively);
      3. otherwise the result of following ``redirects`` from ``name``
         (resolved recursively).
    Any exception raised in steps 2-3 (e.g. an empty ``name`` or a failing
    ``redirects[name]`` lookup) makes the function return None.
    """
    key = normalize(clean(name))
    if key in idx:
        return key
    try:
        if name[0].islower():
            resolved = get_wiki_entry(name[0].upper() + name[1:])
        else:
            resolved = get_wiki_entry(
                recursive_redirect_lookup(redirects, redirects[name]))
        return normalize(clean(resolved))
    except Exception:
        # Best-effort lookup as before, but no longer swallowing
        # KeyboardInterrupt/SystemExit like the former bare ``except:``.
        return None
コード例 #16
0
def evidence_macro_recall(instance, max_evidence=None):
    """Per-claim evidence recall as a ``(hit, counted)`` pair.

    Claims labelled ``NOT ENOUGH INFO`` (case-insensitive) are not scored and
    yield ``(0.0, 0.0)``. Every other claim is counted (second value 1.0). It
    is a hit when some gold evidence group is completely contained in the
    (optionally truncated) predicted evidence. A claim with no gold evidence,
    or with an empty group, is a hit.

    Args:
        instance: Dict with ``"label"``, ``"evidence"`` (list of groups whose
            entries are compared as ``[normalize(entry[0]), entry[1]]``) and
            ``"predicted_evidence"`` (list of 2-element lists).
        max_evidence: If given, only the first ``max_evidence`` predictions
            are considered.
    """
    if instance["label"].upper() == "NOT ENOUGH INFO":
        return 0.0, 0.0

    gold_groups = instance["evidence"]
    # The former check iterated over ``instance`` (its dict keys) instead of
    # the evidence groups, so it never fired.
    if len(gold_groups) == 0 or all(len(group) == 0 for group in gold_groups):
        return 1.0, 1.0

    predicted_evidence = instance["predicted_evidence"]
    if max_evidence is not None:
        predicted_evidence = predicted_evidence[:max_evidence]

    for group in gold_groups:
        # Only complete groups count; partially retrieved groups are worthless.
        if all([normalize(entry[0]), entry[1]] in predicted_evidence for entry in group):
            return 1.0, 1.0
    return 0.0, 1.0
コード例 #17
0
 def get_doc_lines(self, doc_id):
     """Fetch the ``lines`` column of the document with id ``doc_id``.

     The id is normalized with ``utils.normalize`` before the lookup.
     Returns None when no matching row exists.
     """
     cursor = self.connection.cursor()
     try:
         cursor.execute("SELECT lines FROM documents WHERE id = ?",
                        (utils.normalize(doc_id), ))
         result = cursor.fetchone()
     finally:
         # Close the cursor even if normalization or the query raises.
         cursor.close()
     return None if result is None else result[0]
コード例 #18
0
def check_has_answer(answer, doc_ids, PROCESS_DB, PROCESS_TOK):
    """Fetch each document and test it for the answer.

    Returns:
        ``(paragraphs, has_answ)``: the normalized document texts, in
        ``doc_ids`` order, and for each the result of
        ``check_ans(answer, text, PROCESS_TOK)``.
    """
    paragraphs = [utils.normalize(PROCESS_DB.get_doc_text(doc_id)) for doc_id in doc_ids]
    has_answ = [check_ans(answer, paragraph, PROCESS_TOK) for paragraph in paragraphs]
    return paragraphs, has_answ
コード例 #19
0
ファイル: analyze_answer.py プロジェクト: SBUNetSys/DeQA
def get_rank(prediction_, answer_, use_regex_=False):
    """Return the 1-based rank of the first prediction that matches an answer.

    Each entry's ``'span'`` is normalized and scored against ``answer_`` with
    ``metric_max_over_ground_truths``, using ``regex_match_score`` when
    ``use_regex_`` is true and ``exact_match_score`` otherwise. Returns 1000
    when no prediction matches.
    """
    for position, candidate in enumerate(prediction_, start=1):
        scorer = regex_match_score if use_regex_ else exact_match_score
        if metric_max_over_ground_truths(scorer, normalize(candidate['span']), answer_):
            return position
    return 1000
コード例 #20
0
def doc_macro_recall(instance, max_pages=None):
    """Per-claim document recall as a ``(hit, counted)`` pair.

    Claims labelled ``NOT ENOUGH INFO`` (case-insensitive) yield
    ``(0.0, 0.0)``. Every other claim is counted (second value 1.0). It is a
    hit when, for some gold evidence group, every page it cites (entry
    field ``[2]``, normalized) is among the normalized predicted pages. A
    claim with no gold evidence, or with an empty group, is a hit.

    Args:
        instance: Dict with ``"label"``, ``"evidence"`` and ``"predicted_pages"``.
        max_pages: If given, only the first ``max_pages`` predicted pages count.
    """
    if instance["label"].upper() == "NOT ENOUGH INFO":
        return 0.0, 0.0

    gold_groups = instance["evidence"]
    # The former check iterated over ``instance`` (its dict keys) instead of
    # the evidence groups, so it never fired.
    if len(gold_groups) == 0 or all(len(group) == 0 for group in gold_groups):
        return 1.0, 1.0

    predicted_pages = instance["predicted_pages"]
    if max_pages is not None:
        predicted_pages = predicted_pages[:max_pages]
    # A set gives O(1) membership tests below.
    predicted_pages = {normalize(page) for page in predicted_pages}

    for group in gold_groups:
        # A group counts only if every page it cites was retrieved.
        if all(normalize(entry[2]) in predicted_pages for entry in group):
            return 1.0, 1.0
    return 0.0, 1.0
コード例 #21
0
def _pack_paragraphs(paragraphs, question, budget, tokenizer, next_guid):
    """Greedily pack one document's paragraphs into classifier inputs.

    Consecutive paragraphs are merged while their summed token counts stay
    within ``budget``. A paragraph that alone exceeds ``budget`` becomes its
    own input. Returns ``(contexts, next_guid)``: the ``InputExample`` list
    and the first guid not yet used.
    """
    contexts = []
    pending = ""

    def emit(text):
        # Wrap ``text`` as a (passage, question) example with a fresh guid.
        nonlocal next_guid
        contexts.append(InputExample(guid=next_guid, text_a=text, text_b=question,
                                     label="not_answerable"))
        next_guid += 1

    for paragraph in paragraphs:
        paragraph_len = len(tokenizer.tokenize(paragraph))
        if paragraph_len > budget:
            # Too long to merge: flush what is pending, then emit it alone.
            if len(pending) > 1:
                emit(pending)
                pending = ""
            emit(paragraph)
        elif len(tokenizer.tokenize(pending)) + paragraph_len <= budget:
            # Re-join with the blank line the document was split on; plain
            # concatenation glued the last and first words together.
            pending = pending + "\n\n" + paragraph if pending else paragraph
        else:
            emit(pending)
            pending = paragraph
    if len(pending) > 1:
        emit(pending)
    return contexts, next_guid


def rerankDocs(questions, answers, closest_docs, db, classifier, max_seq, args,
               top_k=5):
    """Rerank each question's retrieved documents with a passage classifier.

    Every document is split on blank lines and re-packed
    (``_pack_paragraphs``) into passages that fit, with the question, into
    ``max_seq`` tokens. All passages of one question are scored in a single
    ``getPredictions`` call, together with the per-document passage counts.
    The outputs are paired with the documents in order, and the ``top_k``
    highest-scoring documents are kept.

    Args:
        questions: Question strings, aligned with ``closest_docs``.
        answers: Gold answers, passed through unchanged.
        closest_docs: Iterable of ``(doc_ids, scores)`` pairs; scores unused.
        db: Object exposing ``get_doc_text(doc_id)``.
        classifier: Forwarded to ``init_classifier``.
        max_seq: Token budget for question plus passage.
        args: Forwarded to ``init_classifier`` and ``getPredictions``.
        top_k: Documents kept per question (default 5, as before).

    Returns:
        ``zip(answers, reranked, questions)``. Each ``reranked`` entry is a list
        of ``(score, doc_id)`` pairs sorted by descending score.
    """
    next_guid = 0
    model, tokenizer = init_classifier(classifier, args)
    documents = [
        [(utils.normalize(db.get_doc_text(doc_id)), doc_id) for doc_id in doc_ids]
        for doc_ids, _ in closest_docs
    ]

    return_documents = []
    for question, docs in tqdm(zip(questions, documents)):
        # Tokens left for the passage once the question is accounted for.
        budget = max_seq - len(tokenizer.tokenize(question))
        samples = []
        doc_lengths = []
        for doc, _doc_id in docs:
            # str.strip takes a character set: this trims leading/trailing newlines.
            paragraphs = doc.strip("\n").split("\n\n")
            contexts, next_guid = _pack_paragraphs(paragraphs, question, budget,
                                                   tokenizer, next_guid)
            samples.extend(contexts)
            doc_lengths.append(len(contexts))

        preds = getPredictions(samples, model, tokenizer, max_seq, doc_lengths, args)
        # Presumably one score per document (doc_lengths maps passages back to
        # documents); scores are paired with documents positionally.
        scored = [(pred, doc[1]) for pred, doc in zip(preds, docs)]
        scored.sort(key=lambda pair: pair[0], reverse=True)
        return_documents.append(scored[:top_k])

    return zip(answers, return_documents, questions)
コード例 #22
0
def snippet_parsed_text_2_lines(text):  # this function is new
    """Extract the ``{...}``-wrapped lines of a parsed snippet.

    ``text`` is cut at every ``'}'``. From each whitespace-stripped piece,
    whatever follows its first ``'{'`` is normalized and cleaned into one
    line. Processing stops at the first piece that is empty after stripping.
    A non-empty piece without ``'{'`` raises ValueError.
    """
    result = []
    for piece in map(str.strip, text.split('}')):
        if not piece:
            break
        body = piece[piece.index('{') + 1:]
        result.append(clean_text(normalize(body)))
    return result
コード例 #23
0
def add_missing(save_path):
    """Insert documents from ``all_docs.json`` that are missing from the DB.

    ``all_docs.json`` is read from the current working directory and maps
    document ids to HTML-escaped text. Every id whose normalized form is not
    yet in the ``documents`` table of the SQLite database at ``save_path`` is
    inserted as ``(normalized id, unescaped text)``.
    """
    import html

    connection = sqlite3.connect(save_path, check_same_thread=False)
    try:
        c = connection.cursor()
        c.execute("SELECT id FROM documents")
        known_ids = {row[0] for row in c.fetchall()}

        with open('all_docs.json', 'r') as f:
            all_docs = json.load(f)

        missing = []
        for k, raw_text in all_docs.items():
            doc_id = utils.normalize(k)
            if doc_id in known_ids:
                continue
            # html.unescape replaces HTMLParser().unescape, removed in Python 3.9.
            missing.append((doc_id, html.unescape(raw_text)))
            # Record the NORMALIZED id (the raw key used to be stored), so two
            # keys normalizing to the same id are not both inserted.
            known_ids.add(doc_id)
        c.executemany("INSERT INTO documents VALUES (?,?)", missing)
        connection.commit()
    finally:
        connection.close()
コード例 #24
0
def snippet_text_2_lines(text):  # this is the old function
    """Split an HTML snippet's ``<p>...</p>`` paragraphs into cleaned sentences.

    Pieces between ``'</p>'`` markers that, once stripped, do not start with
    ``'<p>'`` are ignored. The remaining paragraphs are sentence-split with
    ``sent_tokenize``, and each sentence is normalized and cleaned.
    """
    sentences = []
    for chunk in (part.strip() for part in text.split('</p>')):
        if not chunk.startswith('<p>'):
            continue
        for sentence in sent_tokenize(chunk[3:].strip()):
            sentences.append(clean_text(normalize(sentence)))
    return sentences
コード例 #25
0
def get_contents(filename, from_md):
    """Parse a document file into ``(id, text)`` rows.

    If ``from_md`` is truthy, the whole file is a single document whose id is
    the normalized ``filename``. Otherwise each line is a JSON document. It is
    passed through ``PREPROCESS_FN`` when that is set, and falsy results are
    dropped.
    """
    with open(filename) as f:
        if from_md:
            contents = f.read()
            return [(utils.normalize(filename), contents)]
        rows = []
        for line in f:
            doc = json.loads(line)
            if PREPROCESS_FN:
                doc = PREPROCESS_FN(doc)
            if doc:
                rows.append((utils.normalize(doc['id']), doc['text']))
    return rows
コード例 #26
0
def check_ans(answer, paragraph, tokenizer):
    """Return 1 if any answer occurs as a token span in ``paragraph``, else 0.

    The paragraph and each normalized answer string are tokenized with
    ``tokenizer`` and compared as contiguous sequences of lower-cased tokens.
    """
    text = tokenizer.tokenize(paragraph).words(uncased=True)
    for single_answer in answer:
        # Tokenize answers with the tokenizer we were given; the old code
        # read the global PROCESS_TOK, which may be unset or a different one.
        tokens = tokenizer.tokenize(utils.normalize(single_answer)).words(uncased=True)
        span = len(tokens)
        for i in range(len(text) - span + 1):
            if text[i:i + span] == tokens:
                return 1
    return 0
コード例 #27
0
def search(query, n_results=5):
    """Retrieve up to ``n_results`` documents for ``query``.

    Returns:
        A list of ``{"score", "title", "text"}`` dicts in ranker order.
        ``title`` is the document name returned by the ranker and ``text`` is
        its normalized content.
    """
    doc_names, doc_scores = ranker.closest_docs(query, k=n_results)

    results = []
    for position, name in enumerate(doc_names):
        results.append({
            "score": doc_scores[position],
            "title": name,
            "text": utils.normalize(get_doc_text(name)),
        })
    return results
コード例 #28
0
def get_contents(filename):
    """Read a JSON-lines file and return ``(normalized id, text)`` tuples.

    When ``PREPROCESS_FN`` is set, each document is passed through it first.
    Documents that are falsy afterwards are dropped.
    """
    def _rows(handle):
        # Yield one row per usable document, in file order.
        for raw in handle:
            doc = json.loads(raw)
            if PREPROCESS_FN:
                doc = PREPROCESS_FN(doc)
            if doc:
                yield utils.normalize(doc['id']), doc['text']

    with open(filename) as handle:
        return list(_rows(handle))
コード例 #29
0
def split_and_check_hanswer(answer, doc_id, PROCESS_DB, PROCESS_TOK,
                            tokenizer, max_seq=384):
    """Split a document into passages and flag which contain the answer.

    Args:
        answer: Candidate answer strings passed to ``check_ans``.
        doc_id: Id of the document in ``PROCESS_DB``.
        PROCESS_DB: Object exposing ``get_doc_text``.
        PROCESS_TOK: Tokenizer passed to ``check_ans``.
        tokenizer: Passed to ``recontruct_with_max_seq``.
        max_seq: Limit passed to ``recontruct_with_max_seq`` (previously
            hard-coded to 384; presumably a token count — confirm).

    Returns:
        ``(paragraphs, has_answ)``: the passages and, for each, the result
        of ``check_ans``.
    """
    text = utils.normalize(PROCESS_DB.get_doc_text(doc_id))
    # str.strip takes a character set: this trims leading/trailing newlines.
    paragraphs = text.strip('\n').split('\n\n')
    paragraphs = recontruct_with_max_seq(paragraphs, tokenizer, max_seq)
    has_answ = [check_ans(answer, paragraph, PROCESS_TOK) for paragraph in paragraphs]
    return paragraphs, has_answ
コード例 #30
0
def _lookup_evidence_sentence(article, location):
    """Return sentence number ``location`` of ``article``, or None if absent.

    Queries the ``documents`` table through the module-level ``cursor``.
    Column 2 of each row is split into lines of the form
    ``<line number><tab><sentence...>``. Lines that are empty, have a
    non-numeric first field, or have no or empty text are skipped. If several
    rows match, the last one wins.
    """
    evidence_str = None
    cursor.execute("SELECT * FROM documents WHERE id = ?",
                   (utils.normalize(article), ))
    for row in cursor:
        for sentence in row[2].split('\n'):
            if sentence == '':
                continue
            arr = sentence.split('\t')
            if not arr[0].isdigit():
                continue
            line_num = int(arr[0])
            if len(arr) <= 1:
                continue
            sentence = ' '.join(arr[1:])
            if sentence == '':
                continue
            if line_num == location:
                evidence_str = sentence
                break
    return evidence_str


def process(input, output):
    """Write one tab-separated line per retrievable predicted-evidence sentence.

    ``input`` is a JSON-lines file decoded with ``ENCODING``. Its records
    carry ``claim``, ``predicted_evidence`` and optionally ``label``: spaces
    are removed from labels, and a missing label becomes ``'REFUTES'``. For
    every evidence entry the referenced sentence is looked up. If found, a
    line with label, sentence, claim, claim index and evidence fields 0-2 is
    written to ``output`` (``ENCODING``-encoded, CRLF-terminated). Otherwise
    an error is printed. Both files are closed even if an error occurs.
    """
    instances = []
    with open(input, 'rb') as fin:
        for index, raw in enumerate(fin):
            record = json.loads(raw.decode(ENCODING).strip('\r\n'))
            if 'label' in record:
                label = ''.join(record['label'].split(' '))
            else:
                label = 'REFUTES'
            evidences = record['predicted_evidence']
            claim = record['claim']
            instances.append([index, label, claim, evidences])
    # Number of claims read.
    print(len(instances))

    with open(output, 'wb') as fout:
        for index, label, claim, evidences in tqdm(instances):
            for evidence in evidences:
                article = evidence[0]
                location = evidence[1]
                evidence_str = _lookup_evidence_sentence(article, location)
                if evidence_str:
                    fout.write(('%s\t%s\t%s\t%s\t%s\t%d\t%s\r\n' %
                                (label, evidence_str, claim, index, evidence[0],
                                 evidence[1], evidence[2])).encode(ENCODING))
                else:
                    print('Error: cant find %s %d for %s' %
                          (article, location, index))
コード例 #31
0
def get_squad(file_path):
    """Collect every paragraph context from a SQuAD-format JSON file.

    Returns:
        A list of ``(pid, context)`` tuples in file order, where ``pid`` is
        ``"<normalized title> ### <paragraph index>"``.
    """
    with open(file_path) as dataset_file:
        articles = json.load(dataset_file)['data']

    paragraphs = []
    for article in tqdm(articles):
        title = utils.normalize(article["title"])
        paragraphs.extend(
            ("%s ### %d" % (title, idx), paragraph["context"])
            for idx, paragraph in enumerate(article['paragraphs'])
        )
    return paragraphs
コード例 #32
0
ファイル: build_db.py プロジェクト: athiwatp/DrQA
def get_contents(filename):
    """Parse a JSON-lines document file.

    Returns:
        ``[(utils.normalize(doc['id']), doc['text']), ...]`` in file order.
        Each document is first passed through ``PREPROCESS_FN`` when that is
        set, and falsy documents are dropped.
    """
    with open(filename) as f:
        parsed = map(json.loads, f)
        prepared = (PREPROCESS_FN(doc) if PREPROCESS_FN else doc for doc in parsed)
        return [(utils.normalize(doc['id']), doc['text']) for doc in prepared if doc]
コード例 #33
0
def get_contents(filename):
    """Parse a JSON-lines file into ``(id, text, lines)`` rows.

    Each document is optionally transformed by ``PREPROCESS_FN``. Falsy
    results are skipped. The rest yield
    ``(utils.normalize(doc["id"]), doc["text"], doc["lines"])``.
    """
    documents = []
    with open(filename) as source:
        for raw_line in source:
            document = json.loads(raw_line)
            if PREPROCESS_FN:
                document = PREPROCESS_FN(document)
            if document:
                documents.append(
                    (utils.normalize(document["id"]), document["text"], document["lines"])
                )
    return documents
コード例 #34
0
ファイル: interactive.py プロジェクト: athiwatp/DrQA
                    help="Specify GPU device id to use")
args = parser.parse_args()

# Use CUDA only when it is not disabled via args.no_cuda and a device exists.
args.cuda = not args.no_cuda and torch.cuda.is_available()
if not args.cuda:
    logger.info('Running on CPU only.')
else:
    torch.cuda.set_device(args.gpu)
    logger.info('CUDA enabled (GPU %d)' % args.gpu)

# Optional fixed candidate set: one entry per line, stripped, normalized
# and lower-cased.
candidates = None
if args.candidate_file:
    logger.info('Loading candidates from %s' % args.candidate_file)
    with open(args.candidate_file) as f:
        candidates = {utils.normalize(line.strip()).lower() for line in f}
    logger.info('Loaded %d candidates.' % len(candidates))

logger.info('Initializing pipeline...')
DrQA = pipeline.DrQA(
    cuda=args.cuda,
    fixed_candidates=candidates,
    reader_model=args.reader_model,
    ranker_config={'options': {'tfidf_path': args.retriever_model}},
    db_config={'options': {'db_path': args.doc_db}},
    tokenizer=args.tokenizer,
)