import torch

# Helper symbols used below (InputExample, expand_links, tokenize_question,
# tokenize_paragraph) are assumed to come from the project's utils module.


def create_examples(jsn, graph_retriever_config):

    task = graph_retriever_config.task

    examples = []

    '''
    Find the maximum size of the initial context (links are not included)
    '''
    graph_retriever_config.max_context_size = 0

    for data in jsn:

        guid = data['q_id']
        question = data['question']
        context = data['context']  # {context title: paragraph}
        all_linked_paras_dic = {}  # {context title: {linked title: paragraph}}

        '''
        Use TagMe-based context at test time.
        '''
        if graph_retriever_config.tagme:
            assert 'tagged_context' in data

            '''
            Reformat "tagged_context" if needed (cf. the "context" case above)
            '''
            if isinstance(data['tagged_context'], list):
                tagged_context = {c[0]: c[1] for c in data['tagged_context']}
                data['tagged_context'] = tagged_context

            '''
            Append valid paragraphs from "tagged_context" to "context"
            '''
            for tagged_title in data['tagged_context']:
                tagged_text = data['tagged_context'][tagged_title]
                if (tagged_title not in context
                        and tagged_title is not None and tagged_title.strip() != ''
                        and tagged_text is not None and tagged_text.strip() != ''):
                    context[tagged_title] = tagged_text

        '''
        Clean "context" by removing invalid paragraphs
        '''
        removed_keys = []
        for title in context:
            if (title is None or title.strip() == ''
                    or context[title] is None or context[title].strip() == ''):
                removed_keys.append(title)
        for key in removed_keys:
            context.pop(key)

        all_paras = {}
        for title in context:
            all_paras[title] = context[title]

        if graph_retriever_config.expand_links:
            expand_links(context, all_linked_paras_dic, all_paras)

        graph_retriever_config.max_context_size = max(
            graph_retriever_config.max_context_size, len(context))

        examples.append(
            InputExample(guid=guid,
                         q=question,
                         c=context,
                         para_dic=all_linked_paras_dic,
                         s_g=None, r_g=None, all_r_g=None,
                         all_paras=all_paras))

    return examples
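
# Illustrative usage sketch (not part of the original module): a minimal run of
# create_examples on one open-domain QA record. SimpleNamespace stands in for
# the real GraphRetrieverConfig; its attribute names mirror exactly those
# accessed above (task, tagme, expand_links, max_context_size), and the sample
# question/paragraphs are hypothetical. InputExample is assumed to be in scope
# from the project's utils module and to expose the context dict as `.context`
# (as beam_search below relies on).
def _create_examples_demo():
    from types import SimpleNamespace

    config = SimpleNamespace(task='hotpot_open', tagme=False,
                             expand_links=False, max_context_size=0)
    jsn = [{
        'q_id': 'demo-0',
        'question': 'Who wrote the novel that inspired Blade Runner?',
        'context': {
            'Blade Runner': 'Blade Runner is a 1982 science fiction film ...',
            'Philip K. Dick': 'Philip K. Dick was an American novelist ...',
            '  ': None,  # invalid entry; removed by the cleaning pass
        },
    }]

    examples = create_examples(jsn, config)
    assert len(examples) == 1
    assert len(examples[0].context) == 2  # the blank/None paragraph was dropped
    assert config.max_context_size == 2   # tracked while building examples
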
def beam_search(self, input_ids, token_type_ids, attention_mask, examples,
                tokenizer, retriever, split_chunk):
    beam = self.graph_retriever_config.beam
    B = input_ids.size(0)
    N = self.graph_retriever_config.max_para_num

    pred = []
    prob = []

    topk_pred = []
    topk_prob = []

    eos_index = N  # the last slot of the output matrix represents EOS

    init_paragraphs, state = self.encode(input_ids, token_type_ids,
                                         attention_mask,
                                         split_chunk=split_chunk)

    # Output matrix to be populated
    ps = torch.FloatTensor(N + 1, self.s.size(0)).zero_().to(self.s.device)  # (N+1, D)

    for i in range(B):
        init_context_len = len(examples[i].context)

        # Populating the output matrix by the initial encoding
        ps[:init_context_len, :].copy_(init_paragraphs[i, :init_context_len, :])
        ps[-1, :].copy_(init_paragraphs[i, -1, :])

        encoded_titles = set(examples[i].title_order)

        pred_ = [[[], [], 1.0] for _ in range(beam)]  # [hist_1, topk_1, score_1], [hist_2, topk_2, score_2], ...
        prob_ = [[] for _ in range(beam)]

        state_ = state[i:i + 1]  # (1, 1, D)
        state_ = state_.expand(beam, 1, state_.size(2))  # -> (beam, 1, D)
        state_tmp = torch.FloatTensor(state_.size()).zero_().to(state_.device)

        for j in range(self.graph_retriever_config.max_select_num):

            if j > 0:
                input = [p[0][-1] for p in pred_]
                input = torch.LongTensor(input).to(ps.device)
                input = ps[input].unsqueeze(1)  # (beam, 1, D)
                state_ = torch.cat((state_, input), dim=2)  # (beam, 1, 2*D)
                state_ = self.rw(state_)  # (beam, 1, D)

            state_ = self.weight_norm(state_)

            # Opening new links from the previous predictions
            # (populating the output matrix dynamically)
            if j > 0:
                prev_title_size = len(examples[i].title_order)
                new_titles = []
                for b in range(beam):
                    prev_pred = pred_[b][0][-1]

                    if prev_pred == eos_index:
                        continue

                    prev_title = examples[i].title_order[prev_pred]

                    if prev_title not in examples[i].all_linked_paras_dic:
                        if retriever is None:
                            continue
                        else:
                            linked_paras_dic = retriever.get_hyperlinked_abstract_paragraphs(
                                prev_title, examples[i].question)
                            examples[i].all_linked_paras_dic[prev_title] = {}
                            examples[i].all_linked_paras_dic[prev_title].update(linked_paras_dic)
                            examples[i].all_paras.update(linked_paras_dic)

                    for linked_title in examples[i].all_linked_paras_dic[prev_title]:
                        if linked_title in encoded_titles or len(examples[i].title_order) == N:
                            continue

                        encoded_titles.add(linked_title)
                        new_titles.append(linked_title)
                        examples[i].title_order.append(linked_title)

                if len(new_titles) > 0:
                    tokens_q = tokenize_question(examples[i].question, tokenizer)
                    input_ids = []
                    input_masks = []
                    segment_ids = []
                    for linked_title in new_titles:
                        linked_para = examples[i].all_paras[linked_title]

                        input_ids_, input_masks_, segment_ids_ = tokenize_paragraph(
                            linked_para, tokens_q,
                            self.graph_retriever_config.max_seq_length, tokenizer)

                        input_ids.append(input_ids_)
                        input_masks.append(input_masks_)
                        segment_ids.append(segment_ids_)

                    input_ids = torch.LongTensor([input_ids]).to(ps.device)
                    token_type_ids = torch.LongTensor([segment_ids]).to(ps.device)
                    attention_mask = torch.LongTensor([input_masks]).to(ps.device)

                    paragraphs, _ = self.encode(input_ids, token_type_ids,
                                                attention_mask,
                                                split_chunk=split_chunk)
                    paragraphs = paragraphs.squeeze(0)
                    ps[prev_title_size:prev_title_size + len(new_titles)].copy_(
                        paragraphs[:len(new_titles), :])

                    if retriever is not None and self.graph_retriever_config.expand_links:
                        expand_links(examples[i].all_paras,
                                     examples[i].all_linked_paras_dic,
                                     examples[i].all_paras)

            output = torch.bmm(
                state_,
                ps.unsqueeze(0).expand(beam, ps.size(0), ps.size(1)).transpose(1, 2))  # (beam, 1, N+1)
            output = output + self.bias
            output = torch.sigmoid(output)
            output = output.to(self.cpu)

            if j == 0:
                output[:, :, len(examples[i].context):] = 0.0
            else:
                if len(examples[i].title_order) < N:
                    output[:, :, len(examples[i].title_order):N] = 0.0

                for b in range(beam):

                    # Omitting previous predictions
                    for k in range(len(pred_[b][0])):
                        output[b, :, pred_[b][0][k]] = 0.0

                    # Links & topK-based pruning
                    if self.graph_retriever_config.pruning_by_links:
                        if pred_[b][0][-1] == eos_index:
                            output[b, :, :eos_index] = 0.0
                            output[b, :, eos_index] = 1.0

                        elif examples[i].title_order[pred_[b][0][-1]] not in examples[i].all_linked_paras_dic:
                            for k in range(len(examples[i].title_order)):
                                if k not in pred_[b][1]:
                                    output[b, :, k] = 0.0

                        else:
                            for k in range(len(examples[i].title_order)):
                                if (k not in pred_[b][1]
                                        and examples[i].title_order[k]
                                        not in examples[i].all_linked_paras_dic[
                                            examples[i].title_order[pred_[b][0][-1]]]):
                                    output[b, :, k] = 0.0

            # Always select at least min_select_num paragraphs before EOS
            if j <= self.graph_retriever_config.min_select_num - 1:
                output[:, :, -1] = 0.0

            score = [p[2] for p in pred_]
            score = torch.FloatTensor(score)
            score = score.unsqueeze(1).unsqueeze(2)  # (beam, 1, 1)
            score = output * score
            output = output.squeeze(1)  # (beam, N+1)
            score = score.squeeze(1)  # (beam, N+1)

            new_pred_ = []
            new_prob_ = []

            b = 0
            while b < beam:
                s, p = torch.max(score.view(score.size(0) * score.size(1)), dim=0)
                s = s.item()
                p = p.item()

                row = p // score.size(1)
                col = p % score.size(1)

                if j == 0:
                    # All beams are identical at the first step,
                    # so block the whole column to keep the beams distinct
                    score[:, col] = 0.0
                else:
                    score[row, col] = 0.0

                p = [[index for index in pred_[row][0]] + [col],
                     output[row].topk(k=2, dim=0)[1].tolist(),
                     s]
                new_pred_.append(p)

                p = [[p_ for p_ in prb] for prb in prob_[row]] + [output[row].tolist()]
                new_prob_.append(p)

                state_tmp[b].copy_(state_[row])

                b += 1

            pred_ = new_pred_
            prob_ = new_prob_
            state_ = state_.clone()
            state_.copy_(state_tmp)

            # Stop once the best beam has selected EOS
            if pred_[0][0][-1] == eos_index:
                break

        topk_pred.append([])
        topk_prob.append([])
        for index__ in range(beam):

            pred_tmp = []
            for index in pred_[index__][0]:
                if index == eos_index:
                    break
                pred_tmp.append(index)

            if index__ == 0:
                pred.append(pred_tmp)
                prob.append(prob_[0])

            topk_pred[-1].append(pred_tmp)
            topk_prob[-1].append(prob_[index__])

    return pred, prob, topk_pred, topk_prob
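
# Illustrative post-processing sketch (not part of the original module):
# beam_search returns paragraph *indices* into each example's title_order,
# which it extends on the fly as hyperlinked paragraphs get encoded. The
# helper below maps the best beam's index history back to paragraph titles;
# the call in the comment is a hypothetical inference-time invocation, and
# argument values such as split_chunk depend on the surrounding eval script.
def _predictions_to_titles(pred, examples):
    # pred[i] holds the selected paragraph indices for example i, in order,
    # with the EOS index already stripped by beam_search.
    return [[examples[i].title_order[k] for k in pred[i]]
            for i in range(len(pred))]

# Example (hypothetical):
#   pred, prob, topk_pred, topk_prob = model.beam_search(
#       input_ids, token_type_ids, attention_mask,
#       examples=examples, tokenizer=tokenizer,
#       retriever=None, split_chunk=300)
#   reasoning_paths = _predictions_to_titles(pred, examples)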