def improve_question_type_and_answer(data, e2i):
    '''Improve the result of the judgement of question type in training data.

    If the question is a special question (type 0), answer_id is the index of
    the final answer node. Otherwise answer_ids are the indices of the two
    compared nodes plus the result of the comparison (0 / 1). This part is not
    very important to the overall results, but avoids Runtime Errors in rare
    cases.

    Args:
        data (Json): Refined distractor-setting samples.
        e2i (dict): entity2index dict.

    Returns:
        (int, int or (int, int, 0 / 1), string): question_type, answer_id and
            answer_entity.

    Raises:
        ValueError: if the answer cannot be located among the entities, or a
            selecting question cannot be resolved against its two options.
    '''
    answer = data['answer']
    question_type = judge_question_type(data['question'])
    if answer in ('yes', 'no'):
        # a yes/no answer always indicates a binary (type 1) question
        question_type = 1
        answer_entity = answer
    else:
        # otherwise the answer must be extractable as a span over the entities
        answer_entity = fuzzy_retrieve(answer, e2i, 'distractor', 80)
        if answer_entity is None:
            raise ValueError('Cannot find answer: {}'.format(answer))

    if question_type == 0:
        # special question: answer_id is the answer node itself
        return question_type, e2i[answer_entity], answer_entity

    q_edge = data['Q_edge']
    if len(q_edge) != 2:
        if question_type == 1:
            raise ValueError(
                'There must be 2 entities in "Q_edge" for type 1 question.')
        # judgement error for a selecting (type 2) question: fall back to type 0
        return 0, e2i[answer_entity], answer_entity

    # comparison question: the two compared nodes ...
    answer_id = [e2i[q_edge[0][0]], e2i[q_edge[1][0]]]
    if question_type == 1:
        # ... plus the yes/no outcome
        answer_id.append(int(answer == 'yes'))
    else:
        # selecting (type 2) question: decide which option the answer names
        option_a, option_b = q_edge[0][1], q_edge[1][1]
        if answer == option_a:
            answer_id.append(0)
        elif answer == option_b:
            answer_id.append(1)
        else:
            # cannot exactly match an option: fall back to fuzzy similarity
            scores = (fuzz.partial_ratio(answer, option_a),
                      fuzz.partial_ratio(answer, option_b))
            if scores[0] < 50 and scores[1] < 50:
                raise ValueError(
                    'There is no exact match in selecting question. answer: {}'
                    .format(answer))
            answer_id.append(0 if scores[0] > scores[1] else 1)
    return question_type, answer_id, answer_entity
def convert_question_to_samples_bundle(tokenizer, data: 'Json refined', neg=2):
    """Convert one refined distractor-setting sample into a training ``Bundle``.

    For every paragraph (entity) in the sample this builds the
    ``[CLS] question [SEP] clues [SEP] sentence [SEP] ...`` token ids, segment
    ids, [SEP] positions and soft start/end supervision weights for hop spans
    and answer spans. For span (type 0) questions it also samples negative
    nodes and builds the column-normalized adjacency matrix of the gold graph.

    Args:
        tokenizer: Word-Piece tokenizer.
        data (Json refined): One refined sample with 'context',
            'supporting_facts', 'question', 'answer', 'Q_edge' and '_id'.
        neg (int, optional): Number of negative nodes to sample. Defaults to 2.

    Returns:
        Bundle: carrying every attribute named in the module-level ``FIELDS``.

    Raises:
        ValueError: if the answer cannot be found, or a comparison question
            cannot be resolved (same checks as improve_question_type_and_answer).
    """
    context = dict(data['context'])
    # (paragraph title, sentence index) -> outgoing edges of that gold sentence
    gold_sentences_set = dict([
        ((para, sen), edges) for para, sen, edges in data['supporting_facts']
    ])
    e2i = {}  # entity -> node index
    i2e = []  # node index -> entity
    for entity, sens in context.items():
        assert not entity in e2i
        e2i[entity] = len(i2e)
        i2e.append(entity)
    # prev[y]: tokenized "clue" sentences pointing to node y. The shared inner
    # list from `* len(...)` is safe here: entries are rebound below, not mutated.
    prev = [[]] * len(i2e)
    ids, hop_start_weights, hop_end_weights, ans_start_weights, ans_end_weights, segment_ids, sep_positions, additional_nodes = [], [], [], [], [], [], [], []
    tokenized_question = ['[CLS]'] + tokenizer.tokenize(
        data['question']) + ['[SEP]']
    for title_x, sen, edges in data[
            'supporting_facts']:  # TODO: match previous sentence
        for title_y, matched, l, r in edges:
            if title_y not in e2i:  # target paragraph absent: must be the answer node
                assert data['answer'] == title_y
                e2i[title_y] = len(i2e)
                i2e.append(title_y)
                prev.append([])
            if title_x != title_y:
                # the source sentence becomes a clue for the target node
                y = e2i[title_y]
                prev[y] = prev[y] + tokenizer.tokenize(
                    context[title_x][sen]) + ['[SEP]']
    question_type = judge_question_type(data['question'])
    # fix judgement by answer:
    if data['answer'] == 'yes' or data['answer'] == 'no':
        question_type = 1
        answer_entity = data['answer']
    else:
        # find answer entity among the known entities
        answer_entity = fuzzy_retrieve(data['answer'], e2i, 'distractor', 80)
        if answer_entity is None:
            raise ValueError('Cannot find answer: {}'.format(data['answer']))
    if question_type == 0:
        answer_id = e2i[answer_entity]
    elif len(data['Q_edge']) != 2:
        if question_type == 1:
            raise ValueError(
                'There must be 2 entities in "Q_edge" for type 1 question.')
        elif question_type == 2:
            # judgement error: only one option, so treat as a span (type 0) question
            question_type = 0
            answer_id = e2i[answer_entity]
    else:
        # the two compared nodes
        answer_id = [e2i[data['Q_edge'][0][0]], e2i[data['Q_edge'][1][0]]]
        if question_type == 1:
            answer_id.append(int(data['answer'] == 'yes'))
        elif question_type == 2:
            if data['answer'] == data['Q_edge'][0][1]:
                answer_id.append(0)
            elif data['answer'] == data['Q_edge'][1][1]:
                answer_id.append(1)
            else:
                # no exact option match: fall back to fuzzy similarity
                score = (fuzz.partial_ratio(data['answer'],
                                            data['Q_edge'][0][1]),
                         fuzz.partial_ratio(data['answer'],
                                            data['Q_edge'][1][1]))
                if score[0] < 50 and score[1] < 50:
                    raise ValueError(
                        'There is no exact match in selecting question. answer: {}'
                        .format(data['answer']))
                else:
                    answer_id.append(0 if score[0] > score[1] else 1)
        else:
            pass
    # Build per-paragraph model inputs and span supervision weights.
    for entity, sens in context.items():
        num_hop, num_ans = 0, 0
        tokenized_all = tokenized_question + prev[e2i[entity]]
        if len(tokenized_all) > 512:  # BERT's hard length limit
            tokenized_all = tokenized_all[:512]
            print('PREV TOO LONG, id: {}'.format(data['_id']))
        segment_id = [0] * len(tokenized_all)
        sep_position = []
        hop_start_weight = [0] * len(tokenized_all)
        hop_end_weight = [0] * len(tokenized_all)
        ans_start_weight = [0] * len(tokenized_all)
        ans_end_weight = [0] * len(tokenized_all)
        for sen_num, sen in enumerate(sens):
            tokenized_sen = tokenizer.tokenize(sen) + ['[SEP]']
            # stop at the length limit or after 16 sentences
            if len(tokenized_all) + len(tokenized_sen) > 512 or sen_num > 15:
                break
            tokenized_all += tokenized_sen
            segment_id += [sen_num + 1] * len(tokenized_sen)
            sep_position.append(len(tokenized_all) - 1)
            hs_weight = [0] * len(tokenized_sen)
            he_weight = [0] * len(tokenized_sen)
            as_weight = [0] * len(tokenized_sen)
            ae_weight = [0] * len(tokenized_sen)
            if (entity, sen_num) in gold_sentences_set:
                tmp = gold_sentences_set[(entity, sen_num)]
                # locate each gold mention inside the word-piece sequence
                intervals = find_start_end_after_tokenized(
                    tokenizer, tokenized_sen,
                    [matched for _, matched, _, _ in tmp])
                for j, (l, r) in enumerate(intervals):
                    # spans pointing to the answer entity (or any span of a
                    # comparison question) supervise the answer head; the
                    # rest supervise the hop head
                    if tmp[j][0] == answer_entity or question_type > 0:
                        as_weight[l] = ae_weight[r] = 1
                        num_ans += 1
                    else:
                        hs_weight[l] = he_weight[r] = 1
                        num_hop += 1
            hop_start_weight += hs_weight
            hop_end_weight += he_weight
            ans_start_weight += as_weight
            ans_end_weight += ae_weight
        assert len(tokenized_all) <= 512
        # no gold span in this paragraph: put a small weight on [CLS] (position 0)
        if 1 not in hop_start_weight:
            hop_start_weight[0] = 0.1
        if 1 not in ans_start_weight:
            ans_start_weight[0] = 0.1
        ids.append(tokenizer.convert_tokens_to_ids(tokenized_all))
        sep_positions.append(sep_position)
        segment_ids.append(segment_id)
        hop_start_weights.append(hop_start_weight)
        hop_end_weights.append(hop_end_weight)
        ans_start_weights.append(ans_start_weight)
        ans_end_weights.append(ans_end_weight)
    n = len(context)
    edges_in_bundle = []
    if question_type == 0:
        # find all edges and prepare forbidden set (containing answer) for negative sampling
        forbidden = set([])
        for para, sen, edges in data['supporting_facts']:
            for x, matched, l, r in edges:
                edges_in_bundle.append((e2i[para], e2i[x]))
                if x == answer_entity:
                    forbidden.add((para, sen))
        if answer_entity not in context and answer_entity in e2i:
            # the answer has no paragraph of its own: add it as an extra node
            n += 1
            tokenized_all = tokenized_question + prev[e2i[answer_entity]]
            if len(tokenized_all) > 512:
                tokenized_all = tokenized_all[:512]
                print('ANSWER TOO LONG! id: {}'.format(data['_id']))
            additional_nodes.append(
                tokenizer.convert_tokens_to_ids(tokenized_all))
        for i in range(neg):  # build negative node n+i
            father_para = random.choice(list(context.keys()))
            father_sen = random.randrange(len(context[father_para]))
            if (father_para, father_sen) in forbidden:
                # resample once; give up on this negative if still forbidden
                father_para = random.choice(list(context.keys()))
                father_sen = random.randrange(len(context[father_para]))
                if (father_para, father_sen) in forbidden:
                    neg -= 1
                    continue
            tokenized_all = tokenized_question + tokenizer.tokenize(
                context[father_para][father_sen]) + ['[SEP]']
            if len(tokenized_all) > 512:
                tokenized_all = tokenized_all[:512]
                print('NEG TOO LONG! id: {}'.format(data['_id']))
            additional_nodes.append(
                tokenizer.convert_tokens_to_ids(tokenized_all))
            # NOTE(review): the index n + i assumes no earlier negative was
            # skipped; after a `neg -= 1; continue` the remaining edges point
            # one past the node actually appended — TODO confirm intended.
            edges_in_bundle.append((e2i[father_para], n + i))
        n += neg
    assert n == len(additional_nodes) + len(context)
    # adjacency with self-loop weight 2, then column-normalized
    adj = torch.eye(n) * 2
    for x, y in edges_in_bundle:
        adj[x, y] = 1
    adj /= torch.sum(adj, dim=0, keepdim=True)
    _id = data['_id']
    ret = Bundle()
    for field in FIELDS:
        # NOTE: eval() picks up the local variable with the same name as the
        # field, so FIELDS must match the local names above exactly.
        setattr(ret, field, eval(field))
    return ret
def cognitive_graph_propagate(tokenizer,
                              data: 'Json eval(Context as pool)',
                              model1,
                              model2,
                              device,
                              setting: 'distractor / fullwiki' = 'fullwiki',
                              max_new_nodes=5):
    """Answer the question in ``data'' by trained CogQA model.

    Iteratively expands a cognitive graph: System 1 (model1) reads each
    frontier paragraph and predicts hop / answer spans, which add edges and
    (up to ``max_new_nodes``) new nodes; System 2 (model2) then reasons over
    the final graph to produce the answer.

    Args:
        tokenizer (Tokenizer): Word-Piece tokenizer.
        data (Json): Unrefined.
        model1 (nn.Module): System 1 model.
        model2 (nn.Module): System 2 model.
        device (torch.device): Selected device.
        setting (string, optional): 'distractor / fullwiki'. Defaults to 'fullwiki'.
        max_new_nodes (int, optional): Maximum number of new nodes in cognitive graph. Defaults to 5.

    Returns:
        tuple: (gold_ret, ans_ret, graph_ret, ans_nodes_ret)
    """
    context = dict(data['context'])
    # note: `id` shadows the builtin, but only inside this comprehension
    e2i = dict([(entity, id) for id, entity in enumerate(context.keys())])
    n = len(context)
    i2e = [''] * n
    for k, v in e2i.items():
        i2e[v] = k
    prev = [[] for i in range(n)]  # elements: (title, sen_num)
    queue = range(n)  # initial frontier: every given paragraph node
    semantics = [None] * n  # per-node semantic vectors from System 1
    input_masks = [None] * n
    tokenized_question = ['[CLS]'] + tokenizer.tokenize(
        data['question']) + ['[SEP]']

    def construct_infer_batch(queue):
        """Construct next batch (frontier nodes to visit).

        Args:
            queue (list): A queue containing frontier nodes.

        Returns:
            tuple: A batch of inputs
        """
        ids, sep_positions, segment_ids, tokenized_alls, B_starts = [], [], [], [], []
        max_length, max_seps, num_samples = 0, 0, len(queue)
        for x in queue:
            # [CLS] question [SEP] clue sentences [SEP] ... paragraph sentences
            tokenized_all = copy.copy(tokenized_question)
            for title, sen_num in prev[x]:
                tokenized_all += tokenizer.tokenize(
                    context[title][sen_num]) + ['[SEP]']
            if len(tokenized_all) > 512:  # BERT's hard length limit
                tokenized_all = tokenized_all[:512]
                print('PREV TOO LONG, id: {}'.format(data['_id']))
            segment_id = [0] * len(tokenized_all)
            sep_position = []
            # offset where this node's own paragraph sentences begin
            B_starts.append(len(tokenized_all))
            for sen_num, sen in enumerate(context[i2e[x]]):
                tokenized_sen = tokenizer.tokenize(sen) + ['[SEP]']
                if len(tokenized_all) + len(
                        tokenized_sen) > 512 or sen_num > 15:
                    break
                tokenized_all += tokenized_sen
                segment_id += [sen_num + 1] * len(tokenized_sen)
                sep_position.append(len(tokenized_all) - 1)
            max_length = max(max_length, len(tokenized_all))
            max_seps = max(max_seps, len(sep_position))
            tokenized_alls.append(tokenized_all)
            ids.append(tokenizer.convert_tokens_to_ids(tokenized_all))
            sep_positions.append(sep_position)
            segment_ids.append(segment_id)
        # zero-pad everything to the batch maxima
        ids_tensor = torch.zeros((num_samples, max_length),
                                 dtype=torch.long,
                                 device=device)
        sep_positions_tensor = torch.zeros((num_samples, max_seps),
                                           dtype=torch.long,
                                           device=device)
        segment_ids_tensor = torch.zeros((num_samples, max_length),
                                         dtype=torch.long,
                                         device=device)
        input_mask = torch.zeros((num_samples, max_length),
                                 dtype=torch.long,
                                 device=device)
        B_starts = torch.tensor(B_starts, dtype=torch.long, device=device)
        for i in range(num_samples):
            length = len(ids[i])
            ids_tensor[i, :length] = torch.tensor(ids[i], dtype=torch.long)
            sep_positions_tensor[i, :len(sep_positions[i])] = torch.tensor(
                sep_positions[i], dtype=torch.long)
            segment_ids_tensor[i, :length] = torch.tensor(segment_ids[i],
                                                          dtype=torch.long)
            input_mask[i, :length] = 1
        return ids_tensor, segment_ids_tensor, input_mask, sep_positions_tensor, tokenized_alls, B_starts

    gold_ret, ans_nodes = set([]), set([])
    allow_limit = [0, 0]
    while len(queue) > 0:  # visit all nodes in the frontier queue
        ids, segment_ids, input_mask, sep_positions, tokenized_alls, B_starts = construct_infer_batch(
            queue)
        hop_preds, ans_preds, semantics_preds, no_ans_logits = model1(
            ids, segment_ids, input_mask, sep_positions, None, None, None,
            None, B_starts, allow_limit)
        new_queue = []
        assert len(queue) == input_mask.shape[0]
        for i, x in enumerate(queue):
            input_masks[x] = input_mask[i]
            semantics[x] = semantics_preds[i]
            # for hop spans
            for k in range(hop_preds.size()[1]):
                l, r, j = hop_preds[i, k]
                j = j.item()
                if l == 0:  # l == 0 marks the end of valid predictions
                    break
                gold_ret.add((i2e[x], j))  # supporting facts
                orig_text = context[i2e[x]][j]
                pred_slice = tokenized_alls[i][l:r + 1]
                # map the word-piece span back to character offsets
                l, r = find_start_end_before_tokenized(orig_text,
                                                       [pred_slice])[0]
                if l == r == 0:
                    continue
                recovered_matched = orig_text[l:r]
                # in fullwiki the retrieval pool is identified by (title, sen)
                pool = context if setting == 'distractor' else (i2e[x], j)
                matched = fuzzy_retrieve(recovered_matched, pool, setting)
                if matched is not None:
                    if setting == 'fullwiki' and matched not in e2i and n < 10 + max_new_nodes:
                        context_new = get_context_fullwiki(matched)
                        if len(context_new) > 0:  # cannot resovle redirection
                            # create new nodes in the cognitive graph
                            context[matched] = context_new
                            prev.append([])
                            semantics.append(None)
                            input_masks.append(None)
                            e2i[matched] = n
                            i2e.append(matched)
                            n += 1
                    if matched in e2i and e2i[matched] != x:
                        y = e2i[matched]
                        if y not in new_queue and (i2e[x], j) not in prev[y]:
                            # new edge means new clues! update the successor as frontier nodes.
                            new_queue.append(y)
                            prev[y].append(((i2e[x], j)))
            # for ans spans
            for k in range(ans_preds.size()[1]):
                l, r, j = ans_preds[i, k]
                j = j.item()
                if l == 0:
                    break
                gold_ret.add((i2e[x], j))
                orig_text = context[i2e[x]][j]
                pred_slice = tokenized_alls[i][l:r + 1]
                l, r = find_start_end_before_tokenized(orig_text,
                                                       [pred_slice])[0]
                if l == r == 0:
                    continue
                recovered_matched = orig_text[l:r]
                matched = fuzzy_retrieve(recovered_matched,
                                         context,
                                         'distractor',
                                         threshold=70)
                if matched is not None:
                    # answer span resolves to an existing entity node
                    y = e2i[matched]
                    ans_nodes.add(y)
                    if (i2e[x], j) not in prev[y]:
                        prev[y].append(((i2e[x], j)))
                elif n < 10 + max_new_nodes:
                    # unseen answer candidate: create a node with no paragraph
                    context[recovered_matched] = []
                    e2i[recovered_matched] = n
                    i2e.append(recovered_matched)
                    new_queue.append(n)
                    ans_nodes.add(n)
                    prev.append([(i2e[x], j)])
                    semantics.append(None)
                    input_masks.append(None)
                    n += 1
        if len(new_queue) == 0 and len(ans_nodes) == 0 and allow_limit[
                1] < 0.1:
            # must find one answer
            # ``allow'' is an offset of negative threshold.
            # If no ans span is valid, make the minimal gap between negative
            # threshold and probability of ans spans -0.1, and try again.
            prob, pos_in_queue = torch.min(no_ans_logits, dim=0)
            new_queue.append(queue[pos_in_queue])
            allow_limit[1] = prob.item() + 0.1
        queue = new_queue
    question_type = judge_question_type(data['question'])
    if n == 0:
        return set([]), 'yes', [], []
    if n == 1 and question_type > 0:
        ans_ret = 'yes' if question_type == 1 else i2e[0]
        return [(i2e[0], 0)], ans_ret, [], []
    # GCN || CompareNets
    # pad semantics to a common sequence length before stacking
    seq_len = np.max([x.shape[0] for x in semantics])
    for idx in range(len(semantics)):
        if semantics[idx].shape[0] < seq_len:
            semantics[idx] = torch.cat(
                (semantics[idx],
                 torch.zeros(seq_len - semantics[idx].shape[0],
                             semantics[idx].shape[1]).to(device)),
                dim=0)
    # pad input masks likewise
    seq_len = np.max([x.shape[0] for x in input_masks])
    for idx in range(len(input_masks)):
        input_masks[idx] = torch.cat(
            (input_masks[idx],
             torch.zeros(seq_len - input_masks[idx].shape[0],
                         dtype=torch.long).to(device)),
            dim=0)
    semantics = torch.stack(semantics)
    input_masks = torch.stack(input_masks)
    # additive attention mask: 0 for real tokens, -10000 for padding
    input_masks = input_masks.unsqueeze(1).unsqueeze(2)
    input_masks = (1.0 - input_masks) * -10000.0
    input_masks = input_masks.to(dtype=torch.float32)  # fp16 compatibility
    if question_type == 0:
        # span question: rank candidate nodes with the GCN
        adj = torch.eye(n, device=device) * 2
        for x in range(n):
            for title, sen_num in prev[x]:
                adj[e2i[title], x] = 1
        adj /= torch.sum(adj, dim=0, keepdim=True)
        pred = model2.gcn(adj, semantics, input_masks)
        for x in range(n):
            if x not in ans_nodes:  # only answer-candidate nodes may win
                pred[x] -= 10000.
        ans_ret = i2e[torch.argmax(pred).item()]
    else:
        # Take the most golden paragraphs as x,y
        gold_num = torch.zeros(n)
        for title, sen_num in gold_ret:
            gold_num[e2i[title]] += 1
        x, y = gold_num.topk(2)[1].tolist()
        diff_sem = semantics[x][0] - semantics[y][0]
        classifier = model2.both_net if question_type == 1 else model2.select_net
        pred = int(torch.sigmoid(classifier(diff_sem)).item() > 0.5)
        ans_ret = ['no', 'yes'
                   ][pred] if question_type == 1 else [i2e[x], i2e[y]][pred]
    # strip a trailing disambiguation suffix like " (film)"
    ans_ret = re.sub(r' \(.*?\)$', '', ans_ret)
    graph_ret = []
    for x in range(n):
        for title, sen_num in prev[x]:
            graph_ret.append('({}, {}) --> {}'.format(title, sen_num, i2e[x]))
    ans_nodes_ret = [i2e[x] for x in ans_nodes]
    return gold_ret, ans_ret, graph_ret, ans_nodes_ret
def cognitive_graph_propagate(tokenizer,
                              data: 'Json eval(Context as pool)',
                              model,
                              model_cg,
                              device,
                              setting: 'distractor / fullwiki' = 'distractor',
                              max_new_nodes=5):
    """Answer the question in ``data'' by trained CogQA model (distractor-default variant).

    Same propagate-then-reason loop as the preceding function: System 1
    (``model``) predicts hop / answer spans per frontier paragraph, System 2
    (``model_cg``) reasons over the resulting graph.

    NOTE(review): this redefines ``cognitive_graph_propagate``; being the
    later definition, it is the one bound at module level — confirm which
    version is meant to be kept.

    Args:
        tokenizer (Tokenizer): Word-Piece tokenizer.
        data (Json): Unrefined.
        model (nn.Module): System 1 model.
        model_cg (nn.Module): System 2 model.
        device (torch.device): Selected device.
        setting (string, optional): 'distractor / fullwiki'. Defaults to 'distractor'.
        max_new_nodes (int, optional): Maximum number of new nodes. Defaults to 5.

    Returns:
        tuple: (gold_ret, ans_ret, graph_ret, ans_nodes_ret)
    """
    context = dict(data['context'])
    # note: `id` shadows the builtin, but only inside this comprehension
    e2i = dict([(entity, id) for id, entity in enumerate(context.keys())])
    n = len(context)
    i2e = [''] * n
    for k, v in e2i.items():
        i2e[v] = k
    prev = [[] for i in range(n)]  # elements: (title, sen_num)
    queue = range(n)  # initial frontier: every given paragraph node
    semantics = [None] * n
    tokenized_question = ['[CLS]'] + tokenizer.tokenize(
        data['question']) + ['[SEP]']

    def construct_infer_batch(queue):
        """Build the padded tensor batch for the frontier nodes in ``queue``."""
        ids, sep_positions, segment_ids, tokenized_alls, B_starts = [], [], [], [], []
        max_length, max_seps, num_samples = 0, 0, len(queue)
        for x in queue:
            # [CLS] question [SEP] clue sentences [SEP] ... paragraph sentences
            tokenized_all = copy.copy(tokenized_question)
            for title, sen_num in prev[x]:
                tokenized_all += tokenizer.tokenize(
                    context[title][sen_num]) + ['[SEP]']
            if len(tokenized_all) > 512:  # BERT's hard length limit
                tokenized_all = tokenized_all[:512]
                print('PREV TOO LONG, id: {}'.format(data['_id']))
            segment_id = [0] * len(tokenized_all)
            sep_position = []
            # offset where this node's own paragraph sentences begin
            B_starts.append(len(tokenized_all))
            for sen_num, sen in enumerate(context[i2e[x]]):
                tokenized_sen = tokenizer.tokenize(sen) + ['[SEP]']
                if len(tokenized_all) + len(
                        tokenized_sen) > 512 or sen_num > 15:
                    break
                tokenized_all += tokenized_sen
                segment_id += [sen_num + 1] * len(tokenized_sen)
                sep_position.append(len(tokenized_all) - 1)
            max_length = max(max_length, len(tokenized_all))
            max_seps = max(max_seps, len(sep_position))
            tokenized_alls.append(tokenized_all)
            ids.append(tokenizer.convert_tokens_to_ids(tokenized_all))
            sep_positions.append(sep_position)
            segment_ids.append(segment_id)
        # zero-pad everything to the batch maxima
        ids_tensor = torch.zeros((num_samples, max_length),
                                 dtype=torch.long,
                                 device=device)
        sep_positions_tensor = torch.zeros((num_samples, max_seps),
                                           dtype=torch.long,
                                           device=device)
        segment_ids_tensor = torch.zeros((num_samples, max_length),
                                         dtype=torch.long,
                                         device=device)
        input_mask = torch.zeros((num_samples, max_length),
                                 dtype=torch.long,
                                 device=device)
        B_starts = torch.tensor(B_starts, dtype=torch.long, device=device)
        for i in range(num_samples):
            length = len(ids[i])
            ids_tensor[i, :length] = torch.tensor(ids[i], dtype=torch.long)
            sep_positions_tensor[i, :len(sep_positions[i])] = torch.tensor(
                sep_positions[i], dtype=torch.long)
            segment_ids_tensor[i, :length] = torch.tensor(segment_ids[i],
                                                          dtype=torch.long)
            input_mask[i, :length] = 1
        return ids_tensor, segment_ids_tensor, input_mask, sep_positions_tensor, tokenized_alls, B_starts

    gold_ret, ans_nodes = set([]), set([])
    allow_limit = [0, 0]
    while len(queue) > 0:
        ids, segment_ids, input_mask, sep_positions, tokenized_alls, B_starts = construct_infer_batch(
            queue)
        # pdb.set_trace()
        hop_preds, ans_preds, semantics_preds, no_ans_logits = model(
            ids, segment_ids, input_mask, sep_positions, None, None, None,
            None, B_starts, allow_limit)
        new_queue = []
        for i, x in enumerate(queue):
            semantics[x] = semantics_preds[i]
            # for hop spans
            for k in range(hop_preds.size()[1]):
                l, r, j = hop_preds[i, k]
                j = j.item()
                if l == 0:  # l == 0 marks the end of valid predictions
                    break
                gold_ret.add((i2e[x], j))
                orig_text = context[i2e[x]][j]
                pred_slice = tokenized_alls[i][l:r + 1]
                # map the word-piece span back to character offsets
                l, r = find_start_end_before_tokenized(orig_text,
                                                       [pred_slice])[0]
                if l == r == 0:
                    continue
                recovered_matched = orig_text[l:r]
                # in fullwiki the retrieval pool is identified by (title, sen)
                pool = context if setting == 'distractor' else (i2e[x], j)
                matched = fuzzy_retrieve(recovered_matched, pool, setting)
                if matched is not None:
                    if setting == 'fullwiki' and matched not in e2i and n < 10 + max_new_nodes:
                        context_new = get_context_fullwiki(matched)
                        if len(context_new) > 0:  # cannot resovle redirection
                            context[matched] = context_new
                            prev.append([])
                            semantics.append(None)
                            e2i[matched] = n
                            i2e.append(matched)
                            n += 1
                    if matched in e2i and e2i[matched] != x:
                        y = e2i[matched]
                        if y not in new_queue and (i2e[x], j) not in prev[y]:
                            # new edge means new clues: revisit the successor
                            new_queue.append(y)
                            prev[y].append(((i2e[x], j)))
            # for ans spans
            for k in range(ans_preds.size()[1]):
                l, r, j = ans_preds[i, k]
                j = j.item()
                if l == 0:
                    break
                gold_ret.add((i2e[x], j))
                orig_text = context[i2e[x]][j]
                pred_slice = tokenized_alls[i][l:r + 1]
                l, r = find_start_end_before_tokenized(orig_text,
                                                       [pred_slice])[0]
                if l == r == 0:
                    continue
                recovered_matched = orig_text[l:r]
                # pool = context if setting == 'distractor' else (i2e[x], j)
                matched = fuzzy_retrieve(recovered_matched,
                                         context,
                                         'distractor',
                                         threshold=70)
                if matched is not None:
                    # answer span resolves to an existing entity node
                    y = e2i[matched]
                    ans_nodes.add(y)
                    if (i2e[x], j) not in prev[y]:
                        prev[y].append(((i2e[x], j)))
                elif n < 10 + max_new_nodes:
                    # unseen answer candidate: create a node with no paragraph
                    context[recovered_matched] = []
                    e2i[recovered_matched] = n
                    i2e.append(recovered_matched)
                    new_queue.append(n)
                    ans_nodes.add(n)
                    prev.append([(i2e[x], j)])
                    semantics.append(None)
                    n += 1
        if len(new_queue) == 0 and len(ans_nodes) == 0 and allow_limit[
                1] < 0.1:
            # must find one answer: relax the negative threshold and retry the
            # node whose "no answer" logit is smallest
            prob, pos_in_queue = torch.min(no_ans_logits, dim=0)
            new_queue.append(queue[pos_in_queue])
            allow_limit[1] = prob.item() + 0.1
        queue = new_queue
    question_type = judge_question_type(data['question'])
    if n == 0:
        return set([]), 'yes', [], []
    if n == 1 and question_type > 0:
        ans_ret = 'yes' if question_type == 1 else i2e[0]
        return [(i2e[0], 0)], ans_ret, [], []
    # GCN || CompareNets
    # assumes all semantics entries share one shape (no padding here, unlike
    # the variant above) — TODO confirm the model emits fixed-size vectors
    semantics = torch.stack(semantics)
    if question_type == 0:
        # span question: rank candidate nodes with the GCN
        adj = torch.eye(n, device=device) * 2
        for x in range(n):
            for title, sen_num in prev[x]:
                adj[e2i[title], x] = 1
        adj /= torch.sum(adj, dim=0, keepdim=True)
        pred = model_cg.gcn(adj, semantics)
        for x in range(n):
            if x not in ans_nodes:  # only answer-candidate nodes may win
                pred[x] -= 10000.
        ans_ret = i2e[torch.argmax(pred).item()]
    else:
        # Take the most golden paragraphs as x,y
        gold_num = torch.zeros(n)
        for title, sen_num in gold_ret:
            gold_num[e2i[title]] += 1
        x, y = gold_num.topk(2)[1].tolist()
        diff_sem = semantics[x] - semantics[y]
        classifier = model_cg.both_net if question_type == 1 else model_cg.select_net
        # thresholds the raw score at 0.5 (no sigmoid here) — confirm the
        # classifier's output range
        pred = int(classifier(diff_sem).item() > 0.5)
        ans_ret = ['no', 'yes'
                   ][pred] if question_type == 1 else [i2e[x], i2e[y]][pred]
    # strip a trailing disambiguation suffix like " (film)"
    ans_ret = re.sub(r' \(.*?\)$', '', ans_ret)
    graph_ret = []
    for x in range(n):
        for title, sen_num in prev[x]:
            graph_ret.append('({}, {}) --> {}'.format(title, sen_num, i2e[x]))
    ans_nodes_ret = [i2e[x] for x in ans_nodes]
    return gold_ret, ans_ret, graph_ret, ans_nodes_ret