def test_sum_tensors(self):
    inputs = [torch.tensor(1), torch.tensor(2)]
    self.assertEqual(utils.sum_tensors(inputs), torch.tensor(3))

    inputs = [torch.tensor(1), None, torch.tensor(2)]
    self.assertEqual(utils.sum_tensors(inputs), torch.tensor(3))

    inputs = [torch.tensor(1), None, None]
    self.assertEqual(utils.sum_tensors(inputs), torch.tensor(1))

    inputs = [None, None, None]
    self.assertEqual(utils.sum_tensors(inputs), None)
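
# A minimal sketch of the behavior `utils.sum_tensors` is expected to have,
# inferred from the assertions above: sum all non-None entries, and return
# None when every entry is None. This is an illustration only, not the
# library implementation; the name `sum_tensors_sketch` is hypothetical.
from typing import List, Optional  # likely already imported in this module

import torch


def sum_tensors_sketch(
        tensors: List[Optional[torch.Tensor]]) -> Optional[torch.Tensor]:
    present = [t for t in tensors if t is not None]
    if not present:
        return None
    total = present[0]
    for t in present[1:]:
        total = total + t
    return total
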
def _forward(self,
             word_embed: torch.Tensor,
             segment_ids: Optional[torch.LongTensor] = None,
             input_mask: Optional[torch.Tensor] = None,
             memory: Optional[List[torch.Tensor]] = None,
             permute_mask: Optional[torch.Tensor] = None,
             target_mapping: Optional[torch.Tensor] = None,
             bi_data: bool = False,
             clamp_len: Optional[int] = None,
             cache_len: int = 0,
             same_length: bool = False,
             attn_type: str = 'bi',
             two_stream: bool = False) \
        -> Tuple[torch.Tensor, Optional[List[torch.Tensor]]]:
    r"""Compute XLNet representations for the input.

    This layer exists because :class:`XLNetDecoder` computes embeddings in
    the decoder helper. `word_embed` has shape
    `[batch_size, max_time, word_embed_dim]`. Please refer to :meth:`forward`
    for detailed information on the other arguments.
    """
    # seq_len == max_time
    # word_embed: [seq_len, batch_size, word_embed_dim]
    word_embed = word_embed.permute(1, 0, 2)
    # segment_ids: [seq_len, batch_size]
    if segment_ids is not None:
        segment_ids = segment_ids.permute(1, 0)
    # input_mask: [seq_len, batch_size]
    if input_mask is not None:
        input_mask = input_mask.permute(1, 0)
    # memory: a list of length num_layers,
    # each tensor of shape [mem_len, batch_size, hidden_dim]
    if memory is not None:
        memory = [m.permute(1, 0, 2) for m in memory]
    # permute_mask: [seq_len, seq_len, batch_size]
    if permute_mask is not None:
        permute_mask = permute_mask.permute(1, 2, 0)
    # target_mapping: [num_targets, seq_len, batch_size]
    if target_mapping is not None:
        target_mapping = target_mapping.permute(1, 2, 0)

    seq_len, batch_size = word_embed.size()[:2]
    mem_len = memory[0].size(0) if memory is not None else 0
    tot_len = seq_len + mem_len
    reuse_len = self._hparams.reuse_len

    # Construct masks.
    masks: List[Optional[torch.Tensor]] = []

    # Causal attention mask.
    if attn_type == 'uni':
        causal_mask = self._create_causal_attn_mask(
            seq_len, mem_len, same_length)
        # causal_mask: (seq_len, tot_len, 1, 1)
        causal_mask = causal_mask.unsqueeze(2).unsqueeze(3)
        masks.append(causal_mask)
    elif attn_type == 'bi':
        pass
    else:
        raise ValueError(f"Unsupported attention type: {attn_type}")

    # Data mask: input mask & permutation mask.
    if input_mask is not None:
        input_mask = input_mask.expand(seq_len, -1, -1)
    data_mask = sum_tensors([input_mask, permute_mask])
    if data_mask is not None:
        # All positions in memory can be attended to.
        memory_mask = data_mask.new_zeros(seq_len, mem_len, batch_size)
        # data_mask: (seq_len, tot_len, batch_size, 1)
        data_mask = torch.cat([memory_mask, data_mask], dim=1).unsqueeze(3)
        masks.append(data_mask)

    # Exclude the main diagonal (target tokens) from the mask.
    attn_mask = sum_tensors(masks)
    if attn_mask is None:
        final_mask = None
    else:
        attn_mask = (attn_mask > 0)
        final_mask = -torch.eye(seq_len, device=attn_mask.device)
        final_mask = torch.cat(
            [final_mask.new_zeros(seq_len, mem_len), final_mask], dim=-1)
        final_mask = final_mask.unsqueeze(2).unsqueeze(3)
        # final_mask: (seq_len, tot_len, batch_size, 1)
        final_mask = ((attn_mask.float() + final_mask) > 0)

    # Construct segment embedding.
    if segment_ids is not None:
        concat_segment_ids = torch.cat(
            [segment_ids.new_zeros(mem_len, batch_size), segment_ids])
        segment_matrix = (segment_ids.unsqueeze(1) !=
                          concat_segment_ids.unsqueeze(0)).long()
        segment_matrix = F.one_hot(segment_matrix, num_classes=2).float()
    else:
        segment_matrix = None

    pos_embed = self.pos_embed(
        batch_size, seq_len, tot_len, clamp_len, attn_type, bi_data)
    pos_embed = self.dropout(pos_embed)

    states_h = self.dropout(word_embed)
    states_g = None
    if two_stream:
        if target_mapping is not None:
            word_embed_q = self.mask_emb.expand(
                target_mapping.size(0), batch_size, -1)
        else:
            word_embed_q = word_embed
        states_g = self.dropout(word_embed_q)
    new_memory = []

    for idx in range(self._hparams.num_layers):
        cur_memory = memory[idx] if memory is not None else None
        if cache_len > 0:
            new_memory.append(
                self._cache_mem(states_h, cur_memory, cache_len, reuse_len))
        attn_layer: RelativeMultiheadAttention
        attn_layer = self.attn_layers[idx]  # type: ignore
        states_h, states_g = attn_layer(states_h=states_h,
                                        states_g=states_g,
                                        pos_embed=pos_embed,
                                        segment_mat=segment_matrix,
                                        attn_mask_h=final_mask,
                                        attn_mask_g=attn_mask,
                                        target_mapping=target_mapping,
                                        memory=cur_memory)
        states_h = self.ff_layers[idx](states_h)
        if states_g is not None:
            states_g = self.ff_layers[idx](states_g)

    output = self.dropout(states_h if states_g is None else states_g)

    # Now output: [seq_len, batch_size, hidden_dim]
    # new_memory: None, or a list of length num_layers,
    # each tensor of shape [cache_len, batch_size, hidden_dim]
    output = output.permute(1, 0, 2)
    if new_memory is not None:
        new_memory = [m.permute(1, 0, 2) for m in new_memory]

    if cache_len == 0:
        return output, None

    return output, new_memory
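
# NOTE: Illustrative sketch only. `self._cache_mem`, used in the layer loop
# above, is defined elsewhere in this module. The snippet below shows the
# standard XLNet memory-caching scheme that call is understood to perform:
# optionally truncate the current hidden states to `reuse_len`, append them to
# the previous memory, keep only the last `cache_len` steps, and detach the
# result so gradients do not flow across segments. The name and exact
# signature here are assumptions, not the library implementation.
def _cache_mem_sketch(curr_out: torch.Tensor,
                      prev_mem: Optional[torch.Tensor],
                      cache_len: int,
                      reuse_len: Optional[int] = None) -> torch.Tensor:
    if reuse_len is not None and reuse_len > 0:
        # Only the first `reuse_len` positions of the segment are carried over.
        curr_out = curr_out[:reuse_len]
    if prev_mem is None:
        new_mem = curr_out[-cache_len:]
    else:
        # Time is dim 0 here: [seq_len, batch_size, hidden_dim].
        new_mem = torch.cat([prev_mem, curr_out], dim=0)[-cache_len:]
    return new_mem.detach()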