Example #1
def improve_question_type_and_answer(data, e2i):
    '''Refine the question-type judgement for a training sample using other information in the sample.
    
    If the question is a special question (type 0), answer_id is the index of the final answer node. Otherwise answer_id
    holds the indices of the two compared nodes plus the result of the comparison (0 / 1).
    This step has little effect on the overall results, but it avoids runtime errors in rare cases.
    
    Args:
        data (Json): A refined distractor-setting sample.
        e2i (dict): Entity-to-index mapping.
    
    Returns:
        (int, int or (int, int, 0/1), str): question_type, answer_id and answer_entity.
    '''
    question_type = judge_question_type(data['question'])
    # fix judgement by answer
    if data['answer'] == 'yes' or data['answer'] == 'no':
        question_type = 1
        answer_entity = data['answer']
    else:
        # check whether the answer can be extracted as a span
        answer_entity = fuzzy_retrieve(data['answer'], e2i, 'distractor', 80)
        if answer_entity is None:
            raise ValueError('Cannot find answer: {}'.format(data['answer']))

    if question_type == 0:
        answer_id = e2i[answer_entity]
    elif len(data['Q_edge']) != 2:
        if question_type == 1:
            raise ValueError(
                'There must be 2 entities in "Q_edge" for a type 1 question.')
        elif question_type == 2:  # Judgement error, should be type 0
            question_type = 0
            answer_id = e2i[answer_entity]
    else:
        answer_id = [e2i[data['Q_edge'][0][0]],
                     e2i[data['Q_edge'][1][0]]]  # compared nodes
        if question_type == 1:
            answer_id.append(int(data['answer'] == 'yes'))
        elif question_type == 2:
            if data['answer'] == data['Q_edge'][0][1]:
                answer_id.append(0)
            elif data['answer'] == data['Q_edge'][1][1]:
                answer_id.append(1)
            else:  # cannot exactly match an option
                score = (fuzz.partial_ratio(data['answer'],
                                            data['Q_edge'][0][1]),
                         fuzz.partial_ratio(data['answer'],
                                            data['Q_edge'][1][1]))
                if score[0] < 50 and score[1] < 50:
                    raise ValueError(
                        'There is no close match in the selecting question. answer: {}'
                        .format(data['answer']))
                else:
                    answer_id.append(0 if score[0] > score[1] else 1)
    return question_type, answer_id, answer_entity
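A minimal sketch of how a caller might unpack the return value (assuming a hypothetical refined sample named sample and its entity-to-index dict e2i; judge_question_type, fuzzy_retrieve and fuzz are helpers defined elsewhere in the project):

question_type, answer_id, answer_entity = improve_question_type_and_answer(sample, e2i)
if question_type == 0:
    # special (span) question: answer_id is the index of the answer node
    answer_node = answer_id
elif question_type == 1:
    # yes/no question: the last element encodes the answer (1 = 'yes', 0 = 'no')
    node_x, node_y, is_yes = answer_id
else:
    # selecting question (type 2): the last element says which compared node is the answer
    node_x, node_y, which = answer_id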
Example #2
def convert_question_to_samples_bundle(tokenizer, data: 'Json refined', neg=2):
    '''Convert one refined distractor-setting sample into a training Bundle.

    For every context paragraph (plus an extra answer node and up to `neg`
    negative nodes for type 0 questions), this builds token ids, segment ids,
    [SEP] positions and span-supervision weights, together with a
    column-normalized adjacency matrix over all graph nodes.
    '''
    context = dict(data['context'])
    gold_sentences_set = dict([
        ((para, sen), edges) for para, sen, edges in data['supporting_facts']
    ])
    e2i = {}
    i2e = []
    for entity, sens in context.items():
        assert entity not in e2i
        e2i[entity] = len(i2e)
        i2e.append(entity)
    prev = [[]] * len(i2e)

    (ids, hop_start_weights, hop_end_weights, ans_start_weights, ans_end_weights,
     segment_ids, sep_positions, additional_nodes) = [], [], [], [], [], [], [], []
    tokenized_question = ['[CLS]'] + tokenizer.tokenize(
        data['question']) + ['[SEP]']
    # TODO: match previous sentence
    for title_x, sen, edges in data['supporting_facts']:
        for title_y, matched, l, r in edges:
            if title_y not in e2i:  # answer
                assert data['answer'] == title_y
                e2i[title_y] = len(i2e)
                i2e.append(title_y)
                prev.append([])
            if title_x != title_y:
                y = e2i[title_y]
                prev[y] = prev[y] + tokenizer.tokenize(
                    context[title_x][sen]) + ['[SEP]']
    question_type = judge_question_type(data['question'])

    # fix by answer:
    if data['answer'] == 'yes' or data['answer'] == 'no':
        question_type = 1
        answer_entity = data['answer']
    else:
        # find answer entity
        answer_entity = fuzzy_retrieve(data['answer'], e2i, 'distractor', 80)
        if answer_entity is None:
            raise ValueError('Cannot find answer: {}'.format(data['answer']))

    if question_type == 0:
        answer_id = e2i[answer_entity]
    elif len(data['Q_edge']) != 2:
        if question_type == 1:
            raise ValueError(
                'There must be 2 entities in "Q_edge" for a type 1 question.')
        elif question_type == 2:
            # print('Convert type 2 question to 0.\n Question:{}'.format(data['question']))
            question_type = 0
            answer_id = e2i[answer_entity]
    else:
        answer_id = [e2i[data['Q_edge'][0][0]], e2i[data['Q_edge'][1][0]]]
        if question_type == 1:
            answer_id.append(int(data['answer'] == 'yes'))
        elif question_type == 2:
            if data['answer'] == data['Q_edge'][0][1]:
                answer_id.append(0)
            elif data['answer'] == data['Q_edge'][1][1]:
                answer_id.append(1)
            else:
                score = (fuzz.partial_ratio(data['answer'],
                                            data['Q_edge'][0][1]),
                         fuzz.partial_ratio(data['answer'],
                                            data['Q_edge'][1][1]))
                if score[0] < 50 and score[1] < 50:
                    raise ValueError(
                        'There is no close match in the selecting question. answer: {}'
                        .format(data['answer']))
                else:
                    # print('Resolve type 1 or 2 question: {}\n answer: {}'.format(data['question'], data['answer']))
                    answer_id.append(0 if score[0] > score[1] else 1)
        else:
            pass

    for entity, sens in context.items():
        num_hop, num_ans = 0, 0
        tokenized_all = tokenized_question + prev[e2i[entity]]
        if len(tokenized_all) > 512:
            tokenized_all = tokenized_all[:512]
            print('PREV TOO LONG, id: {}'.format(data['_id']))
        segment_id = [0] * len(tokenized_all)
        sep_position = []
        hop_start_weight = [0] * len(tokenized_all)
        hop_end_weight = [0] * len(tokenized_all)
        ans_start_weight = [0] * len(tokenized_all)
        ans_end_weight = [0] * len(tokenized_all)

        for sen_num, sen in enumerate(sens):
            tokenized_sen = tokenizer.tokenize(sen) + ['[SEP]']
            if len(tokenized_all) + len(tokenized_sen) > 512 or sen_num > 15:
                break
            # if sen_num > 10:
            #     raise ValueError('Too many sentences in context: {}'.format(sens))
            tokenized_all += tokenized_sen
            segment_id += [sen_num + 1] * len(tokenized_sen)
            sep_position.append(len(tokenized_all) - 1)
            hs_weight = [0] * len(tokenized_sen)
            he_weight = [0] * len(tokenized_sen)
            as_weight = [0] * len(tokenized_sen)
            ae_weight = [0] * len(tokenized_sen)
            if (entity, sen_num) in gold_sentences_set:
                tmp = gold_sentences_set[(entity, sen_num)]
                intervals = find_start_end_after_tokenized(
                    tokenizer, tokenized_sen,
                    [matched for _, matched, _, _ in tmp])
                for j, (l, r) in enumerate(intervals):
                    if tmp[j][0] == answer_entity or question_type > 0:
                        as_weight[l] = ae_weight[r] = 1
                        num_ans += 1
                    else:
                        hs_weight[l] = he_weight[r] = 1
                        num_hop += 1
            hop_start_weight += hs_weight
            hop_end_weight += he_weight
            ans_start_weight += as_weight
            ans_end_weight += ae_weight

        assert len(tokenized_all) <= 512
        # for i in range(len(start_weight)):
        #     start_weight[i] /= max(num_spans, 1)
        #     end_weight[i] /= max(num_spans, 1)
        # if a node contains no gold span, point the supervision at [CLS] with a small weight
        if 1 not in hop_start_weight:
            hop_start_weight[0] = 0.1
        if 1 not in ans_start_weight:
            ans_start_weight[0] = 0.1

        ids.append(tokenizer.convert_tokens_to_ids(tokenized_all))
        sep_positions.append(sep_position)
        segment_ids.append(segment_id)
        hop_start_weights.append(hop_start_weight)
        hop_end_weights.append(hop_end_weight)
        ans_start_weights.append(ans_start_weight)
        ans_end_weights.append(ans_end_weight)

    n = len(context)
    edges_in_bundle = []
    if question_type == 0:
        # find all edges and prepare forbidden set(containing answer) for negative sampling
        forbidden = set([])
        for para, sen, edges in data['supporting_facts']:
            for x, matched, l, r in edges:
                edges_in_bundle.append((e2i[para], e2i[x]))
                if x == answer_entity:
                    forbidden.add((para, sen))
        if answer_entity not in context and answer_entity in e2i:
            n += 1
            tokenized_all = tokenized_question + prev[e2i[answer_entity]]
            if len(tokenized_all) > 512:
                tokenized_all = tokenized_all[:512]
                print('ANSWER TOO LONG! id: {}'.format(data['_id']))
            additional_nodes.append(
                tokenizer.convert_tokens_to_ids(tokenized_all))
        num_neg = 0
        for _ in range(neg):
            # build negative node n + num_neg from a random non-gold sentence
            father_para = random.choice(list(context.keys()))
            father_sen = random.randrange(len(context[father_para]))
            if (father_para, father_sen) in forbidden:
                # resample once; give up on this negative if it is still a gold sentence
                father_para = random.choice(list(context.keys()))
                father_sen = random.randrange(len(context[father_para]))
            if (father_para, father_sen) in forbidden:
                continue
            tokenized_all = tokenized_question + tokenizer.tokenize(
                context[father_para][father_sen]) + ['[SEP]']
            if len(tokenized_all) > 512:
                tokenized_all = tokenized_all[:512]
                print('NEG TOO LONG! id: {}'.format(data['_id']))
            additional_nodes.append(
                tokenizer.convert_tokens_to_ids(tokenized_all))
            edges_in_bundle.append((e2i[father_para], n + num_neg))
            num_neg += 1
        n += num_neg
    assert n == len(additional_nodes) + len(context)

    adj = torch.eye(n) * 2
    for x, y in edges_in_bundle:
        adj[x, y] = 1
    adj /= torch.sum(adj, dim=0, keepdim=True)

    _id = data['_id']
    ret = Bundle()
    for field in FIELDS:  # collect the local variables named in FIELDS onto the bundle
        setattr(ret, field, eval(field))
    return ret
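For reference, the adjacency construction at the end of this example (self-loops weighted 2, directed edges weighted 1, every column normalized to sum to 1) can be reproduced in isolation. A toy sketch with three nodes follows; Bundle, FIELDS and find_start_end_after_tokenized used above are helpers defined elsewhere in the project and are not needed here.

import torch

n = 3
edges_in_bundle = [(0, 1), (1, 2)]          # toy edges x -> y
adj = torch.eye(n) * 2                      # self-loops get weight 2
for x, y in edges_in_bundle:
    adj[x, y] = 1                           # each edge gets weight 1
adj /= torch.sum(adj, dim=0, keepdim=True)  # normalize every column to sum to 1
print(adj)
# column 1 becomes [1/3, 2/3, 0]: the entry from edge (0, 1) plus node 1's own self-loop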