def convert_examples_to_features(examples, max_seq_length, max_para_num,
                                 graph_retriever_config, tokenizer):
    """Converts a list of examples into a list of `InputFeatures`."""

    # Override the paragraph budget with the configured context size and keep
    # the global max_para_num in sync.
    max_para_num = graph_retriever_config.max_context_size
    graph_retriever_config.max_para_num = max(
        graph_retriever_config.max_para_num, max_para_num)

    max_steps = graph_retriever_config.max_select_num  # unused below

    DUMMY = [0] * max_seq_length
    features = []

    for (ex_index, example) in enumerate(examples):
        tokens_q = tokenize_question(example.question, tokenizer)

        title2index = {}
        input_ids = []
        input_masks = []
        segment_ids = []

        # Tokenize each initial-context paragraph (paired with the question),
        # recording the order in which titles are encoded.
        titles_list = list(example.context.keys())
        for p in titles_list:
            if len(input_ids) == max_para_num:
                break
            if p in title2index:
                continue

            title2index[p] = len(title2index)
            example.title_order.append(p)
            p = example.context[p]

            input_ids_, input_masks_, segment_ids_ = tokenize_paragraph(
                p, tokens_q, max_seq_length, tokenizer)
            input_ids.append(input_ids_)
            input_masks.append(input_masks_)
            segment_ids.append(segment_ids_)

        num_paragraphs_no_links = len(input_ids)

        assert len(input_ids) <= max_para_num

        num_paragraphs = len(input_ids)

        # Output mask: row t marks which slots (paragraphs plus one trailing
        # EOS column) are selectable at step t.
        output_masks = [([1.0] * len(input_ids) +
                         [0.0] * (max_para_num - len(input_ids) + 1))
                        for _ in range(max_para_num + 2)]

        assert len(example.context) == num_paragraphs_no_links

        # At the first step, only the initial-context paragraphs (no linked
        # paragraphs, no EOS) are selectable.
        for i in range(len(output_masks[0])):
            if i >= num_paragraphs_no_links:
                output_masks[0][i] = 0.0

        # Zero the diagonal: paragraph slot i is masked out at step i + 1.
        for i in range(len(input_ids)):
            output_masks[i + 1][i] = 0.0

        # Pad the paragraph dimension up to max_para_num.
        padding = [DUMMY] * (max_para_num - len(input_ids))
        input_ids += padding
        input_masks += padding
        segment_ids += padding

        features.append(
            InputFeatures(input_ids=input_ids,
                          input_masks=input_masks,
                          segment_ids=segment_ids,
                          output_masks=output_masks,
                          num_paragraphs=num_paragraphs,
                          num_steps=-1,
                          ex_index=ex_index))

    return features
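# Usage sketch (not part of the original file): how the features produced by
# convert_examples_to_features are typically packed into tensors for batched
# evaluation. The helper name `build_eval_dataloader` and the batch size are
# illustrative assumptions; the field names follow InputFeatures above.
def build_eval_dataloader(examples, max_seq_length, max_para_num,
                          graph_retriever_config, tokenizer, batch_size=4):
    import torch
    from torch.utils.data import DataLoader, SequentialSampler, TensorDataset

    features = convert_examples_to_features(examples, max_seq_length,
                                            max_para_num,
                                            graph_retriever_config, tokenizer)

    all_input_ids = torch.tensor([f.input_ids for f in features],
                                 dtype=torch.long)
    all_input_masks = torch.tensor([f.input_masks for f in features],
                                   dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in features],
                                   dtype=torch.long)
    all_output_masks = torch.tensor([f.output_masks for f in features],
                                    dtype=torch.float)
    all_num_paragraphs = torch.tensor([f.num_paragraphs for f in features],
                                      dtype=torch.long)
    all_ex_indices = torch.tensor([f.ex_index for f in features],
                                  dtype=torch.long)

    dataset = TensorDataset(all_input_ids, all_input_masks, all_segment_ids,
                            all_output_masks, all_num_paragraphs,
                            all_ex_indices)
    return DataLoader(dataset,
                      sampler=SequentialSampler(dataset),
                      batch_size=batch_size)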
def beam_search(self, input_ids, token_type_ids, attention_mask, examples,
                tokenizer, retriever, split_chunk):
    beam = self.graph_retriever_config.beam
    B = input_ids.size(0)
    N = self.graph_retriever_config.max_para_num

    pred = []
    prob = []

    topk_pred = []
    topk_prob = []

    eos_index = N

    init_paragraphs, state = self.encode(input_ids, token_type_ids,
                                         attention_mask,
                                         split_chunk=split_chunk)

    # Output matrix to be populated
    ps = torch.FloatTensor(N + 1, self.s.size(0)).zero_().to(
        self.s.device)  # (N+1, D)

    for i in range(B):
        init_context_len = len(examples[i].context)

        # Populating the output matrix by the initial encoding
        ps[:init_context_len, :].copy_(
            init_paragraphs[i, :init_context_len, :])
        ps[-1, :].copy_(init_paragraphs[i, -1, :])  # EOS representation

        encoded_titles = set(examples[i].title_order)

        # [hist_1, topk_1, score_1], [hist_2, topk_2, score_2], ...
        pred_ = [[[], [], 1.0] for _ in range(beam)]
        prob_ = [[] for _ in range(beam)]

        state_ = state[i:i + 1]  # (1, 1, D)
        state_ = state_.expand(beam, 1, state_.size(2))  # -> (beam, 1, D)
        state_tmp = torch.FloatTensor(state_.size()).zero_().to(state_.device)

        for j in range(self.graph_retriever_config.max_select_num):

            # Update the search state with the representation of each beam's
            # latest prediction.
            if j > 0:
                input = [p[0][-1] for p in pred_]
                input = torch.LongTensor(input).to(ps.device)
                input = ps[input].unsqueeze(1)  # (beam, 1, D)
                state_ = torch.cat((state_, input), dim=2)  # (beam, 1, 2*D)
                state_ = self.rw(state_)  # (beam, 1, D)
                state_ = self.weight_norm(state_)

            # Opening new links from the previous predictions
            # (populating the output matrix dynamically)
            if j > 0:
                prev_title_size = len(examples[i].title_order)
                new_titles = []
                for b in range(beam):
                    prev_pred = pred_[b][0][-1]
                    if prev_pred == eos_index:
                        continue
                    prev_title = examples[i].title_order[prev_pred]

                    if prev_title not in examples[i].all_linked_paras_dic:
                        if retriever is None:
                            continue
                        else:
                            linked_paras_dic = retriever.get_hyperlinked_abstract_paragraphs(
                                prev_title, examples[i].question)
                            examples[i].all_linked_paras_dic[prev_title] = {}
                            examples[i].all_linked_paras_dic[
                                prev_title].update(linked_paras_dic)
                            examples[i].all_paras.update(linked_paras_dic)

                    for linked_title in examples[i].all_linked_paras_dic[
                            prev_title]:
                        if linked_title in encoded_titles or len(
                                examples[i].title_order) == N:
                            continue
                        encoded_titles.add(linked_title)
                        new_titles.append(linked_title)
                        examples[i].title_order.append(linked_title)

                # Encode the newly linked paragraphs and append them to ps.
                if len(new_titles) > 0:
                    tokens_q = tokenize_question(examples[i].question,
                                                 tokenizer)
                    input_ids = []
                    input_masks = []
                    segment_ids = []
                    for linked_title in new_titles:
                        linked_para = examples[i].all_paras[linked_title]

                        input_ids_, input_masks_, segment_ids_ = tokenize_paragraph(
                            linked_para, tokens_q,
                            self.graph_retriever_config.max_seq_length,
                            tokenizer)
                        input_ids.append(input_ids_)
                        input_masks.append(input_masks_)
                        segment_ids.append(segment_ids_)

                    input_ids = torch.LongTensor([input_ids]).to(ps.device)
                    token_type_ids = torch.LongTensor([segment_ids]).to(
                        ps.device)
                    attention_mask = torch.LongTensor([input_masks]).to(
                        ps.device)

                    paragraphs, _ = self.encode(input_ids, token_type_ids,
                                                attention_mask,
                                                split_chunk=split_chunk)
                    paragraphs = paragraphs.squeeze(0)
                    ps[prev_title_size:prev_title_size +
                       len(new_titles)].copy_(paragraphs[:len(new_titles), :])

                    if retriever is not None and self.graph_retriever_config.expand_links:
                        expand_links(examples[i].all_paras,
                                     examples[i].all_linked_paras_dic,
                                     examples[i].all_paras)

            # Selection probability of each paragraph (and EOS) per beam.
            output = torch.bmm(
                state_,
                ps.unsqueeze(0).expand(beam, ps.size(0), ps.size(1)).transpose(
                    1, 2))  # (beam, 1, N+1)
            output = output + self.bias
            output = torch.sigmoid(output)
            output = output.to(self.cpu)

            if j == 0:
                # Only the initial context paragraphs are selectable at the
                # first step.
                output[:, :, len(examples[i].context):] = 0.0
            else:
                # Mask paragraph slots that have not been populated yet.
                if len(examples[i].title_order) < N:
                    output[:, :, len(examples[i].title_order):N] = 0.0

                for b in range(beam):
                    # Omitting previous predictions
                    for k in range(len(pred_[b][0])):
                        output[b, :, pred_[b][0][k]] = 0.0

                    # Links & top-K-based pruning
                    if self.graph_retriever_config.pruning_by_links:
                        if pred_[b][0][-1] == eos_index:
                            # Once a beam emits EOS, it can only emit EOS.
                            output[b, :, :eos_index] = 0.0
                            output[b, :, eos_index] = 1.0
                        elif examples[i].title_order[
                                pred_[b][0]
                                [-1]] not in examples[i].all_linked_paras_dic:
                            # No outgoing links: restrict to the top-2
                            # candidates from the previous step.
                            for k in range(len(examples[i].title_order)):
                                if k not in pred_[b][1]:
                                    output[b, :, k] = 0.0
                        else:
                            # Restrict to paragraphs linked from the previous
                            # prediction, plus the previous top-2 candidates.
                            for k in range(len(examples[i].title_order)):
                                if k not in pred_[b][1] and examples[
                                        i].title_order[k] not in examples[
                                            i].all_linked_paras_dic[
                                                examples[i].title_order[
                                                    pred_[b][0][-1]]]:
                                    output[b, :, k] = 0.0

            # always >= M before EOS
            if j <= self.graph_retriever_config.min_select_num - 1:
                output[:, :, -1] = 0.0

            # Accumulate beam scores: running beam score times step probability.
            score = [p[2] for p in pred_]
            score = torch.FloatTensor(score)
            score = score.unsqueeze(1).unsqueeze(2)  # (beam, 1, 1)
            score = output * score
            output = output.squeeze(1)  # (beam, N+1)
            score = score.squeeze(1)  # (beam, N+1)

            # Keep the `beam` highest-scoring extensions across all beams.
            new_pred_ = []
            new_prob_ = []

            b = 0
            while b < beam:
                s, p = torch.max(score.view(score.size(0) * score.size(1)),
                                 dim=0)
                s = s.item()
                p = p.item()

                row = p // score.size(1)
                col = p % score.size(1)

                if j == 0:
                    # All beams are identical at the first step, so zero the
                    # whole column to avoid duplicate hypotheses.
                    score[:, col] = 0.0
                else:
                    score[row, col] = 0.0

                p = [[index for index in pred_[row][0]] + [col],
                     output[row].topk(k=2, dim=0)[1].tolist(), s]
                new_pred_.append(p)

                p = [[p_ for p_ in prb]
                     for prb in prob_[row]] + [output[row].tolist()]
                new_prob_.append(p)

                state_tmp[b].copy_(state_[row])

                b += 1

            pred_ = new_pred_
            prob_ = new_prob_
            state_ = state_.clone()
            state_.copy_(state_tmp)

            # Stop once the best beam has emitted EOS.
            if pred_[0][0][-1] == eos_index:
                break

        # Strip EOS and collect the final top-1 and top-k predictions.
        topk_pred.append([])
        topk_prob.append([])
        for index__ in range(beam):
            pred_tmp = []
            for index in pred_[index__][0]:
                if index == eos_index:
                    break
                pred_tmp.append(index)

            if index__ == 0:
                pred.append(pred_tmp)
                prob.append(prob_[0])

            topk_pred[-1].append(pred_tmp)
            topk_prob[-1].append(prob_[index__])

    return pred, prob, topk_pred, topk_prob
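# Usage sketch (not part of the original file): a minimal evaluation driver
# for beam_search, assuming `model` exposes the beam_search method above,
# `dataloader` yields batches in the order built by the sketch after
# convert_examples_to_features, and `tfidf_retriever` is an optional retriever
# (None disables on-the-fly link retrieval inside beam_search). The function
# name and batch unpacking are illustrative assumptions.
def run_beam_search(model, dataloader, eval_examples, tokenizer,
                    tfidf_retriever=None, split_chunk=None, device="cuda"):
    import torch

    model.eval()
    all_pred = []
    with torch.no_grad():
        for (input_ids, input_masks, segment_ids, _output_masks,
             _num_paragraphs, ex_indices) in dataloader:
            batch_examples = [eval_examples[i] for i in ex_indices.tolist()]
            pred, prob, topk_pred, topk_prob = model.beam_search(
                input_ids.to(device),
                segment_ids.to(device),
                input_masks.to(device),
                examples=batch_examples,
                tokenizer=tokenizer,
                retriever=tfidf_retriever,
                split_chunk=split_chunk)
            # `pred` holds the selected paragraph indices (EOS stripped) for
            # each example in the batch.
            all_pred += pred
    return all_pred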