def improve_question_type_and_answer(data, e2i):
    '''Refine the heuristic question-type judgement using the gold answer.

    For a span question (type 0) ``answer_id`` is the index of the final
    answer node. For verification (type 1) and selection (type 2) questions
    ``answer_id`` is ``[left_node, right_node, result]`` where ``result`` is
    0 / 1. This part is not very important to the overall results, but
    avoids Runtime Errors in rare cases.

    Args:
        data (Json): Refined distractor-setting sample.
        e2i (dict): entity2index dict.

    Returns:
        (int, int or (int, int, 0 / 1), string): question_type, answer_id
        and answer_entity.

    Raises:
        ValueError: if the answer cannot be located among the entities, or
            the sample is inconsistent with the judged question type.
    '''
    answer = data['answer']
    question_type = judge_question_type(data['question'])

    # The gold answer overrides the heuristic judgement: a yes/no answer can
    # only belong to a verification question.
    if answer in ('yes', 'no'):
        question_type, answer_entity = 1, answer
    else:
        # Otherwise the answer must be extractable as an entity span.
        answer_entity = fuzzy_retrieve(answer, e2i, 'distractor', 80)
        if answer_entity is None:
            raise ValueError('Cannot find answer: {}'.format(answer))

    if question_type == 0:
        answer_id = e2i[answer_entity]
    elif len(data['Q_edge']) != 2:
        if question_type == 1:
            raise ValueError(
                'There must be 2 entities in "Q_edge" for type 1 question.')
        elif question_type == 2:
            # Judgement error: without two candidates this must be type 0.
            question_type, answer_id = 0, e2i[answer_entity]
    else:
        left, right = data['Q_edge'][0], data['Q_edge'][1]
        answer_id = [e2i[left[0]], e2i[right[0]]]  # the two compared nodes
        if question_type == 1:
            answer_id.append(1 if answer == 'yes' else 0)
        elif question_type == 2:
            if answer == left[1]:
                answer_id.append(0)
            elif answer == right[1]:
                answer_id.append(1)
            else:
                # No exact match against either option: fall back to fuzzy
                # scores, rejecting the sample when both are too weak.
                left_score = fuzz.partial_ratio(answer, left[1])
                right_score = fuzz.partial_ratio(answer, right[1])
                if left_score < 50 and right_score < 50:
                    raise ValueError(
                        'There is no exact match in selecting question. answer: {}'
                        .format(answer))
                answer_id.append(0 if left_score > right_score else 1)
    return question_type, answer_id, answer_entity
def convert_question_to_samples_bundle(tokenizer, data: 'Json refined', neg=2):
    '''Convert one refined distractor-setting sample into a training Bundle.

    Tokenizes the question plus each context paragraph (prefixed with its
    "clue" sentences collected in ``prev``), builds start/end supervision
    weights for hop spans and answer spans, optionally samples negative
    paragraph nodes, and assembles a column-normalized adjacency matrix over
    all nodes.

    Args:
        tokenizer: BERT-style tokenizer providing ``tokenize`` and
            ``convert_tokens_to_ids``.
        data (Json): Refined distractor-setting sample.
        neg (int): Number of negative nodes to try to sample; fewer may be
            produced when sampling keeps hitting forbidden sentences.

    Returns:
        Bundle: object whose attributes are the local variables named in
        ``FIELDS``.
    '''
    context = dict(data['context'])  # entity -> list of sentences
    # (paragraph, sentence_idx) -> edges going out of that gold sentence
    gold_sentences_set = {
        (para, sen): edges for para, sen, edges in data['supporting_facts']
    }
    e2i = {}  # entity -> node index
    i2e = []  # node index -> entity
    for entity, sens in context.items():
        assert entity not in e2i
        e2i[entity] = len(i2e)
        i2e.append(entity)
    # prev[i]: tokens of sentences pointing to node i (clue context).
    # The shared [] literals are safe: entries are only ever rebound below.
    prev = [[]] * len(i2e)
    (ids, hop_start_weights, hop_end_weights, ans_start_weights,
     ans_end_weights, segment_ids, sep_positions,
     additional_nodes) = [], [], [], [], [], [], [], []
    tokenized_question = ['[CLS]'] + tokenizer.tokenize(
        data['question']) + ['[SEP]']
    for title_x, sen, edges in data[
            'supporting_facts']:  # TODO: match previous sentence
        for title_y, matched, l, r in edges:
            if title_y not in e2i:  # entity seen only as a target: the answer
                assert data['answer'] == title_y
                e2i[title_y] = len(i2e)
                i2e.append(title_y)
                prev.append([])
            if title_x != title_y:
                y = e2i[title_y]
                prev[y] = prev[y] + tokenizer.tokenize(
                    context[title_x][sen]) + ['[SEP]']
    # Reuse the shared refinement helper instead of duplicating its ~40
    # lines of type/answer logic here (was an exact copy).
    question_type, answer_id, answer_entity = improve_question_type_and_answer(
        data, e2i)
    for entity, sens in context.items():
        num_hop, num_ans = 0, 0
        tokenized_all = tokenized_question + prev[e2i[entity]]
        if len(tokenized_all) > 512:  # BERT's maximum sequence length
            tokenized_all = tokenized_all[:512]
            print('PREV TOO LONG, id: {}'.format(data['_id']))
        segment_id = [0] * len(tokenized_all)
        sep_position = []
        hop_start_weight = [0] * len(tokenized_all)
        hop_end_weight = [0] * len(tokenized_all)
        ans_start_weight = [0] * len(tokenized_all)
        ans_end_weight = [0] * len(tokenized_all)
        for sen_num, sen in enumerate(sens):
            tokenized_sen = tokenizer.tokenize(sen) + ['[SEP]']
            # Drop sentences that would overflow the input or exceed the
            # per-paragraph sentence budget.
            if len(tokenized_all) + len(tokenized_sen) > 512 or sen_num > 15:
                break
            tokenized_all += tokenized_sen
            segment_id += [sen_num + 1] * len(tokenized_sen)
            sep_position.append(len(tokenized_all) - 1)
            hs_weight = [0] * len(tokenized_sen)
            he_weight = [0] * len(tokenized_sen)
            as_weight = [0] * len(tokenized_sen)
            ae_weight = [0] * len(tokenized_sen)
            if (entity, sen_num) in gold_sentences_set:
                tmp = gold_sentences_set[(entity, sen_num)]
                intervals = find_start_end_after_tokenized(
                    tokenizer, tokenized_sen,
                    [matched for _, matched, _, _ in tmp])
                for j, (l, r) in enumerate(intervals):
                    # Spans that name the answer entity (or any span in a
                    # non-span question) supervise the answer head; the
                    # rest supervise the next-hop head.
                    if tmp[j][0] == answer_entity or question_type > 0:
                        as_weight[l] = ae_weight[r] = 1
                        num_ans += 1
                    else:
                        hs_weight[l] = he_weight[r] = 1
                        num_hop += 1
            hop_start_weight += hs_weight
            hop_end_weight += he_weight
            ans_start_weight += as_weight
            ans_end_weight += ae_weight
        assert len(tokenized_all) <= 512
        # With no gold span, put a small weight on [CLS] as a "no span" signal.
        if 1 not in hop_start_weight:
            hop_start_weight[0] = 0.1
        if 1 not in ans_start_weight:
            ans_start_weight[0] = 0.1
        ids.append(tokenizer.convert_tokens_to_ids(tokenized_all))
        sep_positions.append(sep_position)
        segment_ids.append(segment_id)
        hop_start_weights.append(hop_start_weight)
        hop_end_weights.append(hop_end_weight)
        ans_start_weights.append(ans_start_weight)
        ans_end_weights.append(ans_end_weight)
    n = len(context)
    edges_in_bundle = []
    if question_type == 0:
        # Find all edges and prepare the forbidden set (sentences reaching
        # the answer) which must not serve as negative samples.
        forbidden = set()
        for para, sen, edges in data['supporting_facts']:
            for x, matched, l, r in edges:
                edges_in_bundle.append((e2i[para], e2i[x]))
                if x == answer_entity:
                    forbidden.add((para, sen))
        if answer_entity not in context and answer_entity in e2i:
            # The answer is a node without its own paragraph: add it as an
            # additional node built from its clue sentences only.
            n += 1
            tokenized_all = tokenized_question + prev[e2i[answer_entity]]
            if len(tokenized_all) > 512:
                tokenized_all = tokenized_all[:512]
                print('ANSWER TOO LONG! id: {}'.format(data['_id']))
            additional_nodes.append(
                tokenizer.convert_tokens_to_ids(tokenized_all))
        for i in range(neg):  # build negative nodes
            father_para = random.choice(list(context.keys()))
            father_sen = random.randrange(len(context[father_para]))
            if (father_para, father_sen) in forbidden:
                # One retry, then give up on this negative sample.
                father_para = random.choice(list(context.keys()))
                father_sen = random.randrange(len(context[father_para]))
                if (father_para, father_sen) in forbidden:
                    neg -= 1
                    continue
            tokenized_all = tokenized_question + tokenizer.tokenize(
                context[father_para][father_sen]) + ['[SEP]']
            if len(tokenized_all) > 512:
                tokenized_all = tokenized_all[:512]
                print('NEG TOO LONG! id: {}'.format(data['_id']))
            additional_nodes.append(
                tokenizer.convert_tokens_to_ids(tokenized_all))
            # Bugfix: point the edge at the node just appended. The old
            # `n + i` skewed the index (and could exceed the adjacency
            # matrix) whenever an earlier negative sample was skipped.
            edges_in_bundle.append(
                (e2i[father_para], len(context) + len(additional_nodes) - 1))
        n += neg
    assert n == len(additional_nodes) + len(context)
    adj = torch.eye(n) * 2  # self-loops weighted 2 before normalization
    for x, y in edges_in_bundle:
        adj[x, y] = 1
    adj /= torch.sum(adj, dim=0, keepdim=True)  # column-normalize
    _id = data['_id']
    ret = Bundle()
    # NOTE(review): eval on names from FIELDS is a deliberate in-repo
    # shortcut that gathers the local variables above into the Bundle;
    # FIELDS is trusted, so this is not an injection risk here.
    for field in FIELDS:
        setattr(ret, field, eval(field))
    return ret