def _next_source_patch(self):
    r"""Assemble the next batch of source data and return it.

    Returns:
        dict: The next batch of source data.
    """
    source_text = self.source_text[self.pr:self.pr + self.step]
    source_idx = self.source_idx[self.pr:self.pr + self.step]
    source_length = self.source_length[self.pr:self.pr + self.step]
    # Each source example consists of `l` fixed-width groups of 4 indices,
    # so build the per-group lengths before padding.
    tmp_length = [[4] * l for l in source_length]
    source_idx, _, source_length = pad_sequence(source_idx, tmp_length, self.padding_token_idx, source_length)

    batch_data = {
        'source_text': source_text,
        'source_idx': source_idx.to(self.device),
        'source_length': source_length.to(self.device)
    }

    if hasattr(self, 'source_plan_idx'):
        source_plan_idx = self.source_plan_idx[self.pr:self.pr + self.step]
        source_plan_length = self.source_plan_length[self.pr:self.pr + self.step]
        source_plan_idx, source_plan_length, _ = pad_sequence(
            source_plan_idx, source_plan_length, self.padding_token_idx, None
        )
        batch_data['source_plan_idx'] = source_plan_idx.to(self.device)
        batch_data['source_plan_length'] = source_plan_length.to(self.device)

    return batch_data
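# NOTE: The pad_sequence helper used throughout these batching methods is not
# shown in this excerpt. Judging from its call sites, it takes
# (idx, length, padding_idx, num=None) and returns padded indices, lengths and
# counts as tensors. The function below is only a hypothetical sketch of the
# flat (num is None) case, not the project's actual implementation; the real
# helper also pads nested per-example groups when num is given.
import torch

def pad_sequence_sketch(idx, length, padding_idx):
    # Pad every sequence in `idx` to the batch-wide maximum length.
    max_length = max(length)
    padded_idx = [seq + [padding_idx] * (max_length - seq_len) for seq, seq_len in zip(idx, length)]
    return torch.LongTensor(padded_idx), torch.LongTensor(length), None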
def _next_target_patch(self):
    r"""Assemble the next batch of target data and return it.

    Returns:
        dict: The next batch of target data.
    """
    batch_data = super()._next_target_patch()  # provides target_text

    target_input_idx = self.target_input_idx[self.pr:self.pr + self.step]
    target_output_idx = self.target_output_idx[self.pr:self.pr + self.step]
    target_length = self.target_length[self.pr:self.pr + self.step]
    # Input and output sequences share the same lengths, so pad both with them.
    target_input_idx, target_length, _ = pad_sequence(target_input_idx, target_length, self.padding_token_idx)
    target_output_idx, _, _ = pad_sequence(target_output_idx, target_length, self.padding_token_idx)

    batch_data['target_input_idx'] = target_input_idx.to(self.device)
    batch_data['target_output_idx'] = target_output_idx.to(self.device)
    batch_data['target_length'] = target_length.to(self.device)
    return batch_data
def _next_target_patch(self):
    r"""Assemble the next batch of target data and return it.

    Returns:
        dict: The next batch of target data.
    """
    target_text = self.target_text[self.pr:self.pr + self.step]
    if self.target_idx is not None:
        target_idx = self.target_idx[self.pr:self.pr + self.step]
        target_length = self.target_length[self.pr:self.pr + self.step]
        target_num = self.target_num[self.pr:self.pr + self.step] if self.target_num is not None else None
        target_idx, target_length, target_num = pad_sequence(
            target_idx, target_length, self.padding_token_idx, target_num
        )

        batch_data = {
            'target_text': target_text,
            'target_idx': target_idx.to(self.device),
            'target_length': target_length.to(self.device)
        }
        if target_num is not None:
            batch_data['target_num'] = target_num
        return batch_data
    else:
        return {'target_text': target_text}
def encoder(self, corpus):
    # source_idx (torch.Tensor): shape: [batch_size, max_seq_len].
    source_idx = corpus['source_idx']
    # source_length (torch.Tensor): shape: [batch_size].
    source_length = corpus['source_length']

    # Entity encoder.
    # entity_len_list: length (in tokens) of each entity.
    # entity_len: number of entities in each batch example.
    entity_list, entity_len_list, entity_len, graph = self.mkgraph(corpus)
    entity_list, _, _ = pad_sequence(entity_list, entity_len_list, self.padding_token_idx)
    entity_list = entity_list.to(self.device)
    entity_len_list = torch.tensor(entity_len_list).to(self.device)
    entity_len = torch.tensor(entity_len).to(self.device)
    _, [entity_embeddings, c0] = self.entity_encoder(self.entity_token_embedder(entity_list), entity_len_list)
    entity_embeddings = entity_embeddings.transpose(0, 1).contiguous()
    # Keep the final two hidden states and flatten them into one vector per entity.
    entity_embeddings = entity_embeddings[:, -2:].view(entity_embeddings.size(0), -1)

    # Relation encoder: embed each of the 15 relation types.
    rel_embeddings = self.rel_token_embedder(torch.arange(15).to(self.device))

    # Title encoder.
    source_embeddings, _ = self.source_encoder(self.source_token_embedder(source_idx), source_length)

    # Graph encoder.
    entity_embeddings, root_embeddings = self.graph_encoder(
        entity_embeddings, entity_len, rel_embeddings, graph, self.device
    )
    return entity_embeddings, source_embeddings, root_embeddings, entity_len
def _next_source_patch(self):
    batch_data = super()._next_source_patch()  # provides source_text, source_idx and source_length

    if self.config['is_pgen']:
        # Extended indices give source OOV tokens temporary ids beyond the
        # vocabulary, as required by the pointer-generator copy mechanism.
        source_extended_idx = self.source_extended_idx[self.pr:self.pr + self.step]
        source_extended_idx, _, _ = pad_sequence(
            source_extended_idx, batch_data['source_length'].cpu(), self.padding_token_idx
        )
        source_oovs = self.source_oovs[self.pr:self.pr + self.step]
        extra_zeros = self.get_extra_zeros(source_oovs)

        batch_data['source_extended_idx'] = source_extended_idx.to(self.device)
        batch_data['source_oovs'] = source_oovs
        batch_data['extra_zeros'] = extra_zeros.to(self.device)

    return batch_data
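# NOTE: get_extra_zeros is not shown in this excerpt. In pointer-generator
# setups it typically allocates the extra vocabulary slots for source OOV
# tokens, i.e. a [batch_size, max_oov_num] zero tensor. A hypothetical sketch
# under that assumption:
import torch

def get_extra_zeros_sketch(source_oovs):
    # One extra column per token in the longest OOV list of the batch.
    max_oovs_num = max((len(oovs) for oovs in source_oovs), default=0)
    return torch.zeros(len(source_oovs), max_oovs_num)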
def _next_source_patch(self):
    r"""Assemble the next batch of source data and return it.

    Returns:
        dict: The next batch of source data.
    """
    source_text = self.source_text[self.pr:self.pr + self.step]
    if self.source_idx is not None:
        source_idx = self.source_idx[self.pr:self.pr + self.step]
        source_length = self.source_length[self.pr:self.pr + self.step]
        source_num = self.source_num[self.pr:self.pr + self.step] if self.source_num is not None else None
        source_idx, source_length, source_num = pad_sequence(
            source_idx, source_length, self.padding_token_idx, source_num
        )

        batch_data = {
            'source_text': source_text,
            'source_idx': source_idx.to(self.device),
            'source_length': source_length.to(self.device)
        }
        if source_num is not None:
            batch_data['source_num'] = source_num
        return batch_data
    else:
        return {'source_text': source_text}
def _next_source_patch(self):
    r"""Assemble the next batch of source data and return it.

    Returns:
        dict: The next batch of source data.
    """
    source_text = self.source_text[self.pr:self.pr + self.step]
    source_idx = self.source_idx[self.pr:self.pr + self.step]
    source_length = self.source_length[self.pr:self.pr + self.step]
    source_triple = self.source_triple[self.pr:self.pr + self.step]
    source_triple_idx = self.source_triple_idx[self.pr:self.pr + self.step]
    source_entity = self.source_entity[self.pr:self.pr + self.step]
    target_mention = self.target_mention[self.pr:self.pr + self.step]
    source_idx, source_length, _ = pad_sequence(source_idx, source_length, self.padding_token_idx)

    batch_data = {
        'source_text': source_text,
        'source_idx': source_idx.to(self.device),
        'source_length': source_length.to(self.device),
        'source_triple': source_triple,
        'source_triple_idx': source_triple_idx,
        'source_entity': source_entity,
        'target_mention': target_mention,
    }
    return batch_data