def forward(
    self, decoder_states, decoder_mask, encoder_states, encoder_mask, decoder_mems_list=None, return_mems=False
):
    """
    Args:
        decoder_states: output of the embedding layer (B x L_dec x H)
        decoder_mask: decoder inputs mask (B x L_dec)
        encoder_states: output of the encoder (B x L_enc x H)
        encoder_mask: encoder inputs mask (B x L_enc)
        decoder_mems_list: list of the cached decoder hidden states
            for fast autoregressive generation which will be used instead
            of decoder_states as keys and values if not None
        return_mems: bool, whether to return outputs of all decoder layers
            or the last layer only
    """
    decoder_attn_mask = form_attention_mask(decoder_mask, diagonal=self.diagonal)
    encoder_attn_mask = form_attention_mask(encoder_mask)
    memory_states = self._get_memory_states(decoder_states, decoder_mems_list, 0)
    cached_mems_list = [memory_states]

    for i, layer in enumerate(self.layers):
        decoder_states = layer(decoder_states, decoder_attn_mask, memory_states, encoder_states, encoder_attn_mask)
        memory_states = self._get_memory_states(decoder_states, decoder_mems_list, i + 1)
        cached_mems_list.append(memory_states)

    if return_mems:
        return cached_mems_list
    else:
        return cached_mems_list[-1]
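A minimal sketch of how the cached-mems interface above is typically driven during greedy autoregressive decoding: one new token is embedded per step and the previously returned mems are passed back in as decoder_mems_list. The decoder, embedding, and log_softmax objects, their call signatures beyond this forward(), and the BOS handling are assumptions for illustration, not the library's generator implementation.

import torch


def greedy_decode_sketch(decoder, embedding, log_softmax, encoder_states, encoder_mask, bos_id, max_len):
    # hypothetical helper: feeds one new token per step and reuses cached decoder mems
    batch = encoder_states.size(0)
    tgt = torch.full((batch, 1), bos_id, dtype=torch.long, device=encoder_states.device)
    mems = None
    for step in range(max_len):
        # embed only the most recent token; cached mems supply the earlier positions
        # (the start_pos keyword is an assumed positional-embedding offset)
        dec_states = embedding(tgt[:, -1:], start_pos=step)
        dec_mask = torch.ones_like(tgt[:, -1:], dtype=torch.float)
        mems = decoder(
            dec_states, dec_mask, encoder_states, encoder_mask, decoder_mems_list=mems, return_mems=True
        )
        # the last entry of the returned mems holds the top-layer hidden states
        logits = log_softmax(mems[-1][:, -1:])
        next_token = logits.argmax(dim=-1)
        tgt = torch.cat((tgt, next_token), dim=-1)
    return tgt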
def forward(
    self,
    decoder_states,
    decoder_mask,
    encoder_states,
    encoder_mask,
    decoder_mems_list=None,
    return_mems=False,
    return_mems_as_list=True,
):
    """
    Args:
        decoder_states: output of the embedding layer (B x L_dec x H)
        decoder_mask: decoder inputs mask (B x L_dec)
        encoder_states: output of the encoder (B x L_enc x H)
        encoder_mask: encoder inputs mask (B x L_enc)
        decoder_mems_list: list of the cached decoder hidden states
            for fast autoregressive generation which will be used instead
            of decoder_states as keys and values if not None
        return_mems: bool, whether to return outputs of all decoder layers
            or the last layer only
        return_mems_as_list: bool, when True the cached mems are returned as a list;
            otherwise they are returned as a single stacked tensor
    """
    decoder_attn_mask = form_attention_mask(decoder_mask, diagonal=self.diagonal)
    encoder_attn_mask = form_attention_mask(encoder_mask)
    memory_states = self._get_memory_states(decoder_states, decoder_mems_list, 0)
    if return_mems_as_list:
        cached_mems_list = [memory_states]
    else:
        cached_mems_list = memory_states.unsqueeze(0)

    for i, layer in enumerate(self.layers):
        decoder_states = layer(decoder_states, decoder_attn_mask, memory_states, encoder_states, encoder_attn_mask)
        memory_states = self._get_memory_states(decoder_states, decoder_mems_list, i + 1)
        if return_mems_as_list:
            cached_mems_list.append(memory_states)
        else:
            cached_mems_list = torch.cat((cached_mems_list, memory_states.unsqueeze(0)), dim=0)

    if self.final_layer_norm is not None:
        decoder_states = self.final_layer_norm(decoder_states)
        memory_states = self._get_memory_states(decoder_states, decoder_mems_list, i + 2)
        if return_mems_as_list:
            cached_mems_list.append(memory_states)
        else:
            cached_mems_list = torch.cat((cached_mems_list, memory_states.unsqueeze(0)), dim=0)

    if return_mems:
        return cached_mems_list
    else:
        return cached_mems_list[-1]
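The only difference between the two return formats above is list-of-tensors versus a tensor pre-stacked along a new leading layer dimension. A small sketch of that equivalence, assuming an already-constructed decoder instance and toy inputs (the check itself is illustrative, not from the source):

import torch


def check_mems_formats(decoder, decoder_states, decoder_mask, encoder_states, encoder_mask):
    # hypothetical check: the stacked tensor should match torch.stack over the list form
    decoder = decoder.eval()  # disable dropout so both passes are deterministic
    mems_list = decoder(
        decoder_states, decoder_mask, encoder_states, encoder_mask, return_mems=True, return_mems_as_list=True
    )
    mems_tensor = decoder(
        decoder_states, decoder_mask, encoder_states, encoder_mask, return_mems=True, return_mems_as_list=False
    )
    # one entry for the embeddings, one per layer, plus one if final_layer_norm is set
    assert mems_tensor.shape[0] == len(mems_list)
    assert torch.allclose(torch.stack(mems_list, dim=0), mems_tensor)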
def forward(self, hidden, hidden_mask=None, return_ortho_loss=False):
    """
    Project hidden [B x N x H] to fixed-size [B x k x H].

    return_ortho_loss: if True, also returns a loss term that encourages
        the k attention vectors to be orthogonal
    """
    attention_scores = self.W2(self.act(self.W1(hidden) / self.attn_scale) / self.attn_scale).transpose(-1, -2)

    attention_mask = form_attention_mask(hidden_mask)
    if attention_mask is not None:
        attention_mask.squeeze_(1)
        attention_scores = attention_scores + attention_mask.to(attention_scores.dtype)

    A = torch.softmax(attention_scores, dim=-1)
    M = A @ hidden

    if return_ortho_loss:
        ortho_loss = ((A @ A.transpose(-1, -2)) - torch.eye(self.k).type_as(A)).pow(2).sum()
        return M, ortho_loss
    else:
        return M
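The ortho_loss above is the squared Frobenius-norm penalty ||A A^T - I||^2 used in structured self-attention pooling to push the k attention distributions toward attending to different positions. A self-contained numeric sketch of that penalty, assuming nothing beyond torch (shapes and the random scores are illustrative stand-ins for W2(act(W1(hidden)))):

import torch

B, k, N, H = 2, 4, 7, 16
hidden = torch.randn(B, N, H)
scores = torch.randn(B, k, N)          # stand-in for the learned attention scores
A = torch.softmax(scores, dim=-1)      # B x k x N attention distributions over positions
M = A @ hidden                         # B x k x H fixed-size summary

eye = torch.eye(k).type_as(A)
ortho_loss = ((A @ A.transpose(-1, -2)) - eye).pow(2).sum()
print(M.shape, ortho_loss.item())      # torch.Size([2, 4, 16]) and a non-negative scalar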
def apply_transformer(self, x, padding_mask=None):
    encoder_attn_mask = form_attention_mask(padding_mask)
    if (
        self.layer_drop and self.training
    ):  # Stochastic layer drop as in Huang et al.: https://arxiv.org/pdf/1603.09382.pdf
        for layer in self.layers:
            p = random.random()
            if p > self.layer_drop:
                x = layer(x, encoder_attn_mask, x)
    else:
        for layer in self.layers:
            x = layer(x, encoder_attn_mask, x)
    return x
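A minimal, self-contained sketch of the layer-drop rule used above: during training a layer is executed only when a uniform draw exceeds layer_drop (so each layer is skipped with probability layer_drop), while at evaluation time every layer runs. The toy module below is an assumption for illustration, not the encoder in the source.

import random
import torch
from torch import nn


class TinyLayerDropStack(nn.Module):
    def __init__(self, num_layers=4, hidden=8, layer_drop=0.3):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(hidden, hidden) for _ in range(num_layers))
        self.layer_drop = layer_drop

    def forward(self, x):
        for layer in self.layers:
            if self.training and self.layer_drop and random.random() <= self.layer_drop:
                continue  # skip this layer with probability layer_drop (training only)
            x = torch.relu(layer(x))
        return x


stack = TinyLayerDropStack()
x = torch.randn(2, 8)
stack.train()
_ = stack(x)  # some layers may be skipped
stack.eval()
_ = stack(x)  # all layers run deterministically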