Example #1
def preprocess(data):
    """Converts raw SQuAD-style entries into index-encoded examples.

    Entries whose paragraph exceeds MAX_PARAGRAPH_LENGTH, or whose answer
    span cannot be located, are skipped.
    """
    processed_data = []
    for text in data:
        try:
            # Token-level (start, end) indices of the answer in the
            # normalized paragraph.
            answer_span = get_updated_answer_indices(text)

            paragraph = evaluate.normalize_answer(text['paragraph']).split()
            question = evaluate.normalize_answer(text['question']).split()

            if len(paragraph) > constants.MAX_PARAGRAPH_LENGTH:
                continue

            # Map each token to its GloVe vocabulary index.
            paragraph = [
                glove_embeddings.get_word_index(word) for word in paragraph
            ]
            question = [
                glove_embeddings.get_word_index(word) for word in question
            ]
        except (IndexError, ValueError):
            # list.index in get_updated_answer_indices raises ValueError
            # when the answer offset is not found; drop the entry.
            continue

        processed_data.append({
            'paragraph': paragraph,
            'question': question,
            'answer_start': answer_span[0],
            'answer_end': answer_span[1]
        })

    return processed_data
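
The entries are expected to be dicts with 'paragraph', 'question', 'answer', and 'answer_start' keys (the same fields read by get_updated_answer_indices in Example #4). A minimal usage sketch, assuming the project's evaluate, constants, and glove_embeddings modules are importable; the entry values are illustrative:

entries = [{
    'paragraph': 'The cat sat on the mat.',
    'question': 'Where did the cat sit?',
    'answer': 'the mat',
    'answer_start': 15,  # illustrative character offset of the answer
}]
examples = preprocess(entries)
# Each example holds GloVe word indices plus token-level answer boundaries.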
Example #2
def check_correctness(df, retrieving_procedure):
    """
    Given a dataframe (containing an 'answer' column) and a retrieving function to obtain an answer from each record of the
    dataframe itself, it checks if the real and the retrieved answers are equal and, if not, it appends the answer to a new
    dataframe of wrong answers which is returned.
    """
    correct_answers = []
    wrong_answers = []
    for record_id, record in df.iterrows():
        answer = record['answer']
        n_answer = normalize_answer(answer)
        retrieved = retrieving_procedure(record)
        n_retrieved = normalize_answer(retrieved)
        if n_answer != n_retrieved:
            wrong_answers.append((record_id, answer, n_answer, retrieved, n_retrieved))
        else:
            correct_answers.append((record_id, answer, n_answer, retrieved, n_retrieved))
    correct_df = pd.DataFrame(
        correct_answers,
        columns=['id', 'answer', 'normalized answer', 'retrieved', 'normalized retrieved']
    ).set_index(['id'])
    wrong_df = pd.DataFrame(
        wrong_answers,
        columns=['id', 'answer', 'normalized answer', 'retrieved', 'normalized retrieved']
    ).set_index(['id'])
    return correct_df, wrong_df
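
A usage sketch; the dataframe layout and the string-matching retrieving_procedure below are illustrative assumptions, and normalize_answer must be in scope:

import pandas as pd

df = pd.DataFrame({
    'answer': ['Paris', 'Berlin'],
    'context': ['... Paris ...', '... Munich ...'],
})

def retrieving_procedure(record):
    # Hypothetical retriever; a real one would run a QA model over
    # record['context'].
    return 'Paris' if 'Paris' in record['context'] else 'Munich'

correct_df, wrong_df = check_correctness(df, retrieving_procedure)
print(len(correct_df), 'correct,', len(wrong_df), 'wrong')  # 1 correct, 1 wrong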
Example #3
def get_encoded_input(paragraph, question):
    """Normalizes and encodes a (paragraph, question) pair as GloVe indices.

    Returns None when the paragraph exceeds MAX_PARAGRAPH_LENGTH.
    """
    paragraph = evaluate.normalize_answer(paragraph).split()
    question = evaluate.normalize_answer(question).split()

    if len(paragraph) > constants.MAX_PARAGRAPH_LENGTH:
        return None

    paragraph = [glove_embeddings.get_word_index(word) for word in paragraph]
    question = [glove_embeddings.get_word_index(word) for word in question]

    return (paragraph, question)
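
Since the function returns None for over-long paragraphs, callers should guard the unpack. A minimal sketch:

encoded = get_encoded_input('The cat sat on the mat.', 'Where did the cat sit?')
if encoded is not None:
    paragraph_ids, question_ids = encoded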
Example #4
def get_updated_answer_indices(entry):
    """Maps the character-level answer position to token-level indices."""
    # Character offsets of every occurrence of the answer in the paragraph.
    all_occurrences = [
        match.start() for match in re.finditer(re.escape(entry['answer']),
                                               entry['paragraph'])
    ]
    # Which occurrence the dataset's answer_start offset refers to.
    occurrence = all_occurrences.index(entry['answer_start'])

    normal_p = evaluate.normalize_answer(entry['paragraph']).split()
    normal_a = evaluate.normalize_answer(entry['answer']).split()

    # Pick the matching occurrence among the sublist matches found in the
    # normalized paragraph tokens.
    return find_sublist(normal_p, normal_a)[occurrence]
Example #5
def make_squad_examples(
    data,
    tokenizer,
    word2id,
    name=None,
    max_context_len=300,
    max_answer_len=10,
    max_question_len=20,
):
    """
    Given a SQuAD dataset, builds a list of example dicts (see implementation).
    """
    examples = []
    total = 0
    total_passed = 0
    skipped = defaultdict(int)
    span2position = make_span2position(seq_size=max_context_len,
                                       max_len=max_answer_len)
    position2span = {v: k for k, v in span2position.items()}

    for line in tqdm(data['data'], desc=name):
        title = line['title']

        for paragraph in line['paragraphs']:
            # Extract context
            context = paragraph['context']
            context = fix_double_quotes(context)
            context_tokens = tokenize(context, tokenizer=tokenizer)

            if max_context_len and len(context_tokens) > max_context_len:
                skipped['context too long'] += len(paragraph['qas'])
                total += len(paragraph['qas'])
                continue

            answer_map = index_by_starting_character(context, context_tokens)

            for qa in paragraph['qas']:
                # Extract question
                question = qa['question']
                question = fix_double_quotes(question)
                question_tokens = tokenize(question, tokenizer=tokenizer)

                # Extract answer
                answer = qa['answers'][0]['text']
                answer = fix_double_quotes(answer)
                answer_tokens = tokenize(answer, tokenizer=tokenizer)

                if max_answer_len and len(answer_tokens) > max_answer_len:
                    skipped['answer too long'] += 1
                    total += 1
                    continue

                answer_start = qa['answers'][0]['answer_start']
                answer_end = answer_start + len(answer)

                # Find answer span
                try:
                    # Length of the last answer token, used to step back to
                    # the first character of the last word.
                    last_word_answer = len(answer_tokens[-1])

                    _, span_start = answer_map[answer_start]  # start token index
                    _, span_end = answer_map[answer_end - last_word_answer]

                    extracted_answer = context_tokens[span_start:span_end + 1]
                    extracted_answer = ' '.join(extracted_answer)
                    extracted_answer = evaluate.normalize_answer(
                        extracted_answer)

                    actual_clean = evaluate.normalize_answer(answer)

                    assert extracted_answer == actual_clean, f'{extracted_answer} != {actual_clean}'

                    span_positions = [span2position[(span_start, span_end)]]

                    s, e = position2span[span_positions[0]]
                    assert ' '.join(context_tokens[s:e+1]) == ' '.join(answer_tokens), \
                        'Extracted span does not match answer'

                    correct_spans = np.asarray([
                        v for k, v in span2position.items()
                        if np.all(np.asarray(k) < len(context_tokens))
                    ])
                    span_mask = np.zeros(len(span2position))
                    span_mask[correct_spans] = 1

                    example = {
                        'title': title,
                        'context_raw': context_tokens,
                        'question_raw': question_tokens,
                        'answer_raw': answer_tokens,
                        'context': pad_seq([word2id[w] for w in context_tokens],
                                           maxlen=max_context_len),
                        'question': pad_seq([word2id[w] for w in question_tokens],
                                            maxlen=max_question_len),
                        'answer': pad_seq([word2id[w] for w in answer_tokens],
                                          maxlen=max_answer_len),
                        'context_len': len_or_maxlen(context_tokens, max_context_len),
                        'question_len': len_or_maxlen(question_tokens, max_question_len),
                        'answer_len': len_or_maxlen(answer_tokens, max_answer_len),
                        'starts': [span_start],
                        'ends': [span_end],
                        'span_positions': span_positions,
                        'span_mask': span_mask,
                        'label': np.asarray([
                            1 if x in span_positions else 0
                            for x in span2position.values()
                        ])
                    }

                    total_passed += 1
                    total += 1

                    examples.append(example)

                except (AssertionError, KeyError):
                    skipped['error finding span'] += 1
                    total += 1
                    continue

    total_skipped = sum(skipped.values())
    ratio_skipped = total_skipped / total if total != 0 else 0
    logging.info(f'max_context_len: {max_context_len}')
    logging.info(f'max_answer_len: {max_answer_len}')
    logging.info(f'skipped {total_skipped}/{total}\t({ratio_skipped})')
    print(json.dumps(skipped, indent=4))
    print(f'ratio skipped: {ratio_skipped}')
    print(f'{total_passed} examples PASSED')
    return examples
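
A hedged usage sketch; the file path and tokenizer name are placeholders, and word2id plus helpers such as make_span2position and tokenize are assumed to come from the surrounding project:

import json

with open('train-v1.1.json') as f:  # illustrative path to a SQuAD file
    squad = json.load(f)

examples = make_squad_examples(
    squad,
    tokenizer='nltk',    # hypothetical tokenizer identifier
    word2id=word2id,     # vocabulary mapping built elsewhere
    name='train',
)
print(f'{len(examples)} examples built')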
Example #6
    model.load_state_dict(torch.load('model.pt', map_location=device))
    model = model.to(device)
    print("Model loaded.")

    model.eval()
    print("Starting evaluation...")
    starts, ends = [], []
    num_batches = len(loader)
    for idx, batch in enumerate(loader):
        if (idx + 1) % 100 == 0:
            print(f'Batch {idx + 1:{len(str(num_batches))}}/{num_batches}')
        with torch.no_grad():
            s, e = model(batch.to(device))
        starts.append(s)
        ends.append(e)
    print("Evaluation completed.")

    df['pred_start'] = [s.item() for ss in starts for s in ss]
    df['pred_end'] = [e.item() for ee in ends for e in ee]

    print("Retrieving predictions...")
    predictions = {}
    for record_id, record in df.iterrows():
        retrieved = retrieving_procedure(record)
        n_retrieved = normalize_answer(retrieved)
        predictions[record_id] = n_retrieved
    print("Finish retrieving.")

    with open('predictions.json', 'w') as f:
        json.dump(predictions, f)
Example #7
def get_score(p):
    """Counts the tokens shared by the two normalized texts in pair p."""
    c1, c2 = p
    c1_tokens = normalize_answer(c1).split()
    c2_tokens = normalize_answer(c2).split()
    # Multiset intersection: each common token counts as many times as it
    # appears in both texts.
    common = Counter(c1_tokens) & Counter(c2_tokens)
    return sum(common.values())
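
A minimal sketch, assuming normalize_answer lowercases and strips punctuation and articles as in the SQuAD evaluation script:

from collections import Counter

print(get_score(('The Eiffel Tower', 'eiffel tower')))  # -> 2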