Example #1
def convert_question_to_samples_bundle(tokenizer, data: 'Json refined', neg=2):
    '''Make training samples.

    Convert distractor-setting samples (question + 10 paragraphs + answer + supporting facts) into bundles.

    Args:
        tokenizer (BertTokenizer): BERT tokenizer used to split sentences into word pieces.
        data (Json): A refined distractor-setting sample with a gold-only cognitive graph.
        neg (int, optional): Defaults to 2. Number of negative answer nodes to add to every sample.

    Raises:
        ValueError: If the question type is invalid.

    Returns:
        Bundle: A bundle containing one sample per paragraph (10 in total) plus the extra negative answer nodes.
    '''
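    # Shape of ``data`` as inferred from the accesses below (not a documented schema):
    #   data['context']:          pairs of (entity, [sentence, ...]), one per paragraph
    #   data['supporting_facts']: triples of (entity, sen_idx, edges), each edge being
    #                             a 4-tuple (target_entity, matched_span, l, r)
    #   data['question'], data['answer'], data['_id']: question text, answer text, sample id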

    context = dict(data['context'])  # maps each entity (paragraph title) to its list of sentences
    gold_sentences_set = dict([
        ((para, sen), edges) for para, sen, edges in data['supporting_facts']
    ])
    e2i, i2e = {}, []  # entity2index, index2entity
    for entity, sens in context.items():
        e2i[entity] = len(i2e)
        i2e.append(entity)
    # pre-extracted clues; the shared empty-list references are safe because
    # clues[y] is only ever rebound below, never mutated in place
    clues = [[]] * len(i2e)

    ids, hop_start_weights, hop_end_weights, ans_start_weights, ans_end_weights, segment_ids, sep_positions, additional_nodes = [], [], [], [], [], [], [], []
    tokenized_question = ['[CLS]'] + tokenizer.tokenize(
        data['question']) + ['[SEP]']

    # Extract clues for entities in the gold-only cognitive graph
    for entity_x, sen, edges in data['supporting_facts']:
        for entity_y, _, _, _ in edges:
            if entity_y not in e2i:  # entity y must be the answer
                assert data['answer'] == entity_y
                e2i[entity_y] = len(i2e)
                i2e.append(entity_y)
                clues.append([])
            if entity_x != entity_y:
                y = e2i[entity_y]
                clues[y] = clues[y] + tokenizer.tokenize(
                    context[entity_x][sen]) + ['[SEP]']

    question_type, answer_id, answer_entity = improve_question_type_and_answer(
        data, e2i)

    # Construct training samples
    for entity, sens in context.items():
        num_hop, num_ans = 0, 0
        tokenized_all = tokenized_question + clues[e2i[entity]]
        if len(tokenized_all) > 512:  # BERT-base accepts at most 512 tokens
            tokenized_all = tokenized_all[:512]
            print('CLUES TOO LONG, id: {}'.format(data['_id']))
        # initialize a sample for ``entity''
        sep_position = []
        segment_id = [0] * len(tokenized_all)
        hop_start_weight = [0] * len(tokenized_all)
        hop_end_weight = [0] * len(tokenized_all)
        ans_start_weight = [0] * len(tokenized_all)
        ans_end_weight = [0] * len(tokenized_all)

        for sen_num, sen in enumerate(sens):
            tokenized_sen = tokenizer.tokenize(sen) + ['[SEP]']
            if len(tokenized_all) + len(tokenized_sen) > 512 or sen_num > 15:
                break
            tokenized_all += tokenized_sen
            segment_id += [sen_num + 1] * len(tokenized_sen)
            sep_position.append(len(tokenized_all) - 1)
            hs_weight = [0] * len(tokenized_sen)
            he_weight = [0] * len(tokenized_sen)
            as_weight = [0] * len(tokenized_sen)
            ae_weight = [0] * len(tokenized_sen)
            if (entity, sen_num) in gold_sentences_set:
                edges = gold_sentences_set[(entity, sen_num)]
                intervals = find_start_end_after_tokenized(
                    tokenizer, tokenized_sen,
                    [matched for _, matched, _, _ in edges])
                for j, (l, r) in enumerate(intervals):
                    # the successor node edges[j][0] is an answer node
                    if edges[j][0] == answer_entity or question_type > 0:
                        as_weight[l] = ae_weight[r] = 1
                        num_ans += 1
                    else:  # edges[j][0] is next-hop node
                        hs_weight[l] = he_weight[r] = 1
                        num_hop += 1
            hop_start_weight += hs_weight
            hop_end_weight += he_weight
            ans_start_weight += as_weight
            ans_end_weight += ae_weight

        assert len(tokenized_all) <= 512
        # if entity is a negative node, train negative threshold at [CLS]
        if 1 not in hop_start_weight:
            hop_start_weight[0] = 0.1
        if 1 not in ans_start_weight:
            ans_start_weight[0] = 0.1

        ids.append(tokenizer.convert_tokens_to_ids(tokenized_all))
        sep_positions.append(sep_position)
        segment_ids.append(segment_id)
        hop_start_weights.append(hop_start_weight)
        hop_end_weights.append(hop_end_weight)
        ans_start_weights.append(ans_start_weight)
        ans_end_weights.append(ans_end_weight)

    # Construct negative answer nodes for task #2 (answer node prediction)
    n = len(context)
    edges_in_bundle = []
    if question_type == 0:
        # find all edges and prepare the forbidden set (containing the answer) for negative sampling
        forbidden = set()
        for para, sen, edges in data['supporting_facts']:
            for x, matched, l, r in edges:
                edges_in_bundle.append((e2i[para], e2i[x]))
                if x == answer_entity:
                    forbidden.add((para, sen))
        if answer_entity not in context and answer_entity in e2i:
            n += 1
            tokenized_all = tokenized_question + clues[e2i[answer_entity]]
            if len(tokenized_all) > 512:
                tokenized_all = tokenized_all[:512]
                print('ANSWER TOO LONG! id: {}'.format(data['_id']))
            additional_nodes.append(
                tokenizer.convert_tokens_to_ids(tokenized_all))

        for i in range(neg):
            # build a negative answer node at index n
            father_para = random.choice(list(context.keys()))
            father_sen = random.randrange(len(context[father_para]))
            if (father_para, father_sen) in forbidden:
                # resample once; if the new pick is still forbidden, give up on this node
                father_para = random.choice(list(context.keys()))
                father_sen = random.randrange(len(context[father_para]))
            if (father_para, father_sen) in forbidden:
                neg -= 1
                continue
            tokenized_all = tokenized_question + tokenizer.tokenize(
                context[father_para][father_sen]) + ['[SEP]']
            if len(tokenized_all) > 512:
                tokenized_all = tokenized_all[:512]
                print('NEG TOO LONG! id: {}'.format(data['_id']))
            additional_nodes.append(
                tokenizer.convert_tokens_to_ids(tokenized_all))
            edges_in_bundle.append((e2i[father_para], n))
            n += 1

    if question_type >= 1:
        for para, sen, edges in data['supporting_facts']:
            for x, matched, l, r in edges:
                if e2i[para] < n and e2i[x] < n:
                    edges_in_bundle.append((e2i[para], e2i[x]))

    assert n == len(additional_nodes) + len(context)
    adj = torch.eye(n) * 2
    for x, y in edges_in_bundle:
        adj[x, y] = 1
    adj /= torch.sum(adj, dim=0, keepdim=True)

    _id = data['_id']
    ret = Bundle()
    for field in FIELDS:  # collect the local variables named in FIELDS into the bundle
        setattr(ret, field, eval(field))
    return ret
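A minimal usage sketch for the function above. The refined-data file name and the choice of the pytorch_pretrained_bert tokenizer are assumptions for illustration, not the project's confirmed setup:

import json
from pytorch_pretrained_bert import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
with open('hotpot_train_refined.json') as f:  # hypothetical refined-data file
    refined_samples = json.load(f)

bundles = []
for sample in refined_samples:
    try:
        bundles.append(convert_question_to_samples_bundle(tokenizer, sample, neg=2))
    except ValueError as err:  # e.g. the answer cannot be matched to any entity
        print('skipping {}: {}'.format(sample['_id'], err))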
Example #2
def convert_question_to_samples_bundle(tokenizer, data: 'Json refined', neg=2):
    context = dict(data['context'])
    gold_sentences_set = dict([
        ((para, sen), edges) for para, sen, edges in data['supporting_facts']
    ])
    e2i = {}
    i2e = []
    for entity, sens in context.items():
        assert entity not in e2i
        e2i[entity] = len(i2e)
        i2e.append(entity)
    prev = [[]] * len(i2e)

    ids, hop_start_weights, hop_end_weights, ans_start_weights, ans_end_weights, segment_ids, sep_positions, additional_nodes = [], [], [], [], [], [], [], []
    tokenized_question = ['[CLS]'] + tokenizer.tokenize(
        data['question']) + ['[SEP]']
    # TODO: match previous sentence
    for title_x, sen, edges in data['supporting_facts']:
        for title_y, matched, l, r in edges:
            if title_y not in e2i:  # answer
                assert data['answer'] == title_y
                e2i[title_y] = len(i2e)
                i2e.append(title_y)
                prev.append([])
            if title_x != title_y:
                y = e2i[title_y]
                prev[y] = prev[y] + tokenizer.tokenize(
                    context[title_x][sen]) + ['[SEP]']
    question_type = judge_question_type(data['question'])

    # fix by answer:
    if data['answer'] == 'yes' or data['answer'] == 'no':
        question_type = 1
        answer_entity = data['answer']
    else:
        # find answer entity
        answer_entity = fuzzy_retrieve(data['answer'], e2i, 'distractor', 80)
        if answer_entity is None:
            raise ValueError('Cannot find answer: {}'.format(data['answer']))

    if question_type == 0:
        answer_id = e2i[answer_entity]
    elif len(data['Q_edge']) != 2:
        if question_type == 1:
            raise ValueError(
                'There must be 2 entities in "Q_edge" for type 1 question.')
        elif question_type == 2:
            # print('Convert type 2 question to 0.\n Question:{}'.format(data['question']))
            question_type = 0
            answer_id = e2i[answer_entity]
    else:
        answer_id = [e2i[data['Q_edge'][0][0]], e2i[data['Q_edge'][1][0]]]
        if question_type == 1:
            answer_id.append(int(data['answer'] == 'yes'))
        elif question_type == 2:
            if data['answer'] == data['Q_edge'][0][1]:
                answer_id.append(0)
            elif data['answer'] == data['Q_edge'][1][1]:
                answer_id.append(1)
            else:
                score = (fuzz.partial_ratio(data['answer'],
                                            data['Q_edge'][0][1]),
                         fuzz.partial_ratio(data['answer'],
                                            data['Q_edge'][1][1]))
                if score[0] < 50 and score[1] < 50:
                    raise ValueError(
                        'There is no exact match in selecting question. answer: {}'
                        .format(data['answer']))
                else:
                    # print('Resolve type 1 or 2 question: {}\n answer: {}'.format(data['question'], data['answer']))
                    answer_id.append(0 if score[0] > score[1] else 1)
        else:
            pass

    for entity, sens in context.items():
        num_hop, num_ans = 0, 0
        tokenized_all = tokenized_question + prev[e2i[entity]]
        if len(tokenized_all) > 512:
            tokenized_all = tokenized_all[:512]
            print('PREV TOO LONG, id: {}'.format(data['_id']))
        segment_id = [0] * len(tokenized_all)
        sep_position = []
        hop_start_weight = [0] * len(tokenized_all)
        hop_end_weight = [0] * len(tokenized_all)
        ans_start_weight = [0] * len(tokenized_all)
        ans_end_weight = [0] * len(tokenized_all)

        for sen_num, sen in enumerate(sens):
            tokenized_sen = tokenizer.tokenize(sen) + ['[SEP]']
            if len(tokenized_all) + len(tokenized_sen) > 512 or sen_num > 15:
                break
            # if sen_num > 10:
            #     raise ValueError('Too many sentences in context: {}'.format(sens))
            tokenized_all += tokenized_sen
            segment_id += [sen_num + 1] * len(tokenized_sen)
            sep_position.append(len(tokenized_all) - 1)
            hs_weight = [0] * len(tokenized_sen)
            he_weight = [0] * len(tokenized_sen)
            as_weight = [0] * len(tokenized_sen)
            ae_weight = [0] * len(tokenized_sen)
            if (entity, sen_num) in gold_sentences_set:
                tmp = gold_sentences_set[(entity, sen_num)]
                intervals = find_start_end_after_tokenized(
                    tokenizer, tokenized_sen,
                    [matched for _, matched, _, _ in tmp])
                for j, (l, r) in enumerate(intervals):
                    if tmp[j][0] == answer_entity or question_type > 0:
                        as_weight[l] = ae_weight[r] = 1
                        num_ans += 1
                    else:
                        hs_weight[l] = he_weight[r] = 1
                        num_hop += 1
            hop_start_weight += hs_weight
            hop_end_weight += he_weight
            ans_start_weight += as_weight
            ans_end_weight += ae_weight

        assert len(tokenized_all) <= 512
        # for i in range(len(start_weight)):
        #     start_weight[i] /= max(num_spans, 1)
        #     end_weight[i] /= max(num_spans, 1)
        if 1 not in hop_start_weight:
            hop_start_weight[0] = 0.1
        if 1 not in ans_start_weight:
            ans_start_weight[0] = 0.1

        ids.append(tokenizer.convert_tokens_to_ids(tokenized_all))
        sep_positions.append(sep_position)
        segment_ids.append(segment_id)
        hop_start_weights.append(hop_start_weight)
        hop_end_weights.append(hop_end_weight)
        ans_start_weights.append(ans_start_weight)
        ans_end_weights.append(ans_end_weight)

    n = len(context)
    edges_in_bundle = []
    if question_type == 0:
        # find all edges and prepare the forbidden set (containing the answer) for negative sampling
        forbidden = set()
        for para, sen, edges in data['supporting_facts']:
            for x, matched, l, r in edges:
                edges_in_bundle.append((e2i[para], e2i[x]))
                if x == answer_entity:
                    forbidden.add((para, sen))
        if answer_entity not in context and answer_entity in e2i:
            n += 1
            tokenized_all = tokenized_question + prev[e2i[answer_entity]]
            if len(tokenized_all) > 512:
                tokenized_all = tokenized_all[:512]
                print('ANSWER TOO LONG! id: {}'.format(data['_id']))
            additional_nodes.append(
                tokenizer.convert_tokens_to_ids(tokenized_all))
        added = 0  # negatives actually added (forbidden picks may be skipped)
        for i in range(neg):
            # build negative node n + added
            father_para = random.choice(list(context.keys()))
            father_sen = random.randrange(len(context[father_para]))
            if (father_para, father_sen) in forbidden:
                # resample once; if still forbidden, skip this negative node
                father_para = random.choice(list(context.keys()))
                father_sen = random.randrange(len(context[father_para]))
            if (father_para, father_sen) in forbidden:
                continue
            tokenized_all = tokenized_question + tokenizer.tokenize(
                context[father_para][father_sen]) + ['[SEP]']
            if len(tokenized_all) > 512:
                tokenized_all = tokenized_all[:512]
                print('NEG TOO LONG! id: {}'.format(data['_id']))
            additional_nodes.append(
                tokenizer.convert_tokens_to_ids(tokenized_all))
            edges_in_bundle.append((e2i[father_para], n + added))
            added += 1
        n += added
    assert n == len(additional_nodes) + len(context)

    adj = torch.eye(n) * 2
    for x, y in edges_in_bundle:
        adj[x, y] = 1
    adj /= torch.sum(adj, dim=0, keepdim=True)

    _id = data['_id']
    ret = Bundle()
    for field in FIELDS:
        setattr(ret, field, eval(field))
    return ret
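Both examples end by building the same column-normalized adjacency matrix over the n graph nodes, with self-loops weighted 2. A small stand-alone illustration of that step (the toy edge list is invented for the demo):

import torch

n = 3
edges_in_bundle = [(0, 1), (0, 2), (1, 2)]  # toy (source, target) pairs

adj = torch.eye(n) * 2                       # self-loop weight 2 on every node
for x, y in edges_in_bundle:
    adj[x, y] = 1                            # edge weight 1 from node x to node y
adj /= torch.sum(adj, dim=0, keepdim=True)   # normalize each column to sum to 1

print(adj)
# tensor([[1.0000, 0.3333, 0.2500],
#         [0.0000, 0.6667, 0.2500],
#         [0.0000, 0.0000, 0.5000]])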