Example 1
    def forward(
        self, decoder_states, decoder_mask, encoder_states, encoder_mask, decoder_mems_list=None, return_mems=False
    ):
        """
        Args:
            decoder_states: output of the embedding layer (B x L_dec x H)
            decoder_mask: decoder inputs mask (B x L_dec)
            encoder_states: output of the encoder (B x L_enc x H)
            encoder_mask: encoder inputs mask (B x L_enc)
            decoder_mems_list: list of the cached decoder hidden states
                for fast autoregressive generation which will be used instead
                of decoder_states as keys and values if not None
            return_mems: bool, whether to return outputs of all decoder layers
                or the last layer only
        """
        decoder_attn_mask = form_attention_mask(decoder_mask, diagonal=self.diagonal)
        encoder_attn_mask = form_attention_mask(encoder_mask)
        memory_states = self._get_memory_states(decoder_states, decoder_mems_list, 0)
        cached_mems_list = [memory_states]

        for i, layer in enumerate(self.layers):
            decoder_states = layer(decoder_states, decoder_attn_mask, memory_states, encoder_states, encoder_attn_mask)
            memory_states = self._get_memory_states(decoder_states, decoder_mems_list, i + 1)
            cached_mems_list.append(memory_states)

        if return_mems:
            return cached_mems_list
        else:
            return cached_mems_list[-1]
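
All of these examples call form_attention_mask to turn a token-level padding mask (plus, for the decoder, the diagonal argument for causal masking) into a mask that is added to raw attention scores, which Example 3 below does explicitly. The snippet that follows is a minimal sketch of that idea, not the library's actual implementation; the helper name build_additive_attention_mask, the NEG_INF fill value, and the exact output shapes are assumptions made for illustration.

import torch

NEG_INF = -10000.0  # assumed fill value for blocked positions

def build_additive_attention_mask(input_mask, diagonal=None):
    """Illustrative stand-in for form_attention_mask (not the library's actual code).

    input_mask: (B x L) tensor with 1 for real tokens and 0 for padding.
    diagonal: when given, positions above this diagonal offset are also
        blocked (diagonal=0 gives the usual causal mask for decoders).
    Returns an additive mask: 0 where attention is allowed, NEG_INF where it
    is blocked, shaped (B x 1 x 1 x L), or (B x 1 x L x L) with a diagonal.
    """
    if input_mask is None:
        return None
    allowed = input_mask.bool().unsqueeze(1)  # B x 1 x L: block padded keys
    if diagonal is not None:
        length = input_mask.shape[1]
        causal = torch.tril(torch.ones(length, length, dtype=torch.bool), diagonal)
        allowed = allowed & causal            # B x L x L: also block future keys
    return (~allowed.unsqueeze(1)).float() * NEG_INF  # added to attention scores

# Example: one sequence with one padded position, causal masking.
mask = torch.tensor([[1, 1, 1, 0]])
print(build_additive_attention_mask(mask, diagonal=0).squeeze())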
Example 2
    def forward(
        self,
        decoder_states,
        decoder_mask,
        encoder_states,
        encoder_mask,
        decoder_mems_list=None,
        return_mems=False,
        return_mems_as_list=True,
    ):
        """
        Args:
            decoder_states: output of the embedding layer (B x L_dec x H)
            decoder_mask: decoder inputs mask (B x L_dec)
            encoder_states: output of the encoder (B x L_enc x H)
            encoder_mask: encoder inputs mask (B x L_enc)
            decoder_mems_list: list of the cached decoder hidden states
                for fast autoregressive generation which will be used instead
                of decoder_states as keys and values if not None
            return_mems: bool, whether to return outputs of all decoder layers
                or the last layer only
            return_mems_as_list: bool, when True, mems are returned as a list;
                otherwise they are returned as a single stacked Tensor
        """
        decoder_attn_mask = form_attention_mask(decoder_mask,
                                                diagonal=self.diagonal)
        encoder_attn_mask = form_attention_mask(encoder_mask)
        memory_states = self._get_memory_states(decoder_states,
                                                decoder_mems_list, 0)
        if return_mems_as_list:
            cached_mems_list = [memory_states]
        else:
            cached_mems_list = memory_states.unsqueeze(0)

        for i, layer in enumerate(self.layers):
            decoder_states = layer(decoder_states, decoder_attn_mask,
                                   memory_states, encoder_states,
                                   encoder_attn_mask)
            memory_states = self._get_memory_states(decoder_states,
                                                    decoder_mems_list, i + 1)
            if return_mems_as_list:
                cached_mems_list.append(memory_states)
            else:
                cached_mems_list = torch.cat(
                    (cached_mems_list, memory_states.unsqueeze(0)), dim=0)

        if self.final_layer_norm is not None:
            decoder_states = self.final_layer_norm(decoder_states)
            memory_states = self._get_memory_states(decoder_states,
                                                    decoder_mems_list, i + 2)
            if return_mems_as_list:
                cached_mems_list.append(memory_states)
            else:
                cached_mems_list = torch.cat(
                    (cached_mems_list, memory_states.unsqueeze(0)), dim=0)

        if return_mems:
            return cached_mems_list
        else:
            return cached_mems_list[-1]
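
Compared with Example 1, the only differences are the optional final layer norm and the return_mems_as_list switch: the per-layer memory states are either collected in a Python list or concatenated into one tensor along a new leading layer dimension. The toy snippet below (batch size, lengths and layer count are arbitrary) only illustrates that the two cache layouts hold the same data; it is not the module itself.

import torch

B, L, H, num_layers = 2, 5, 8, 3
# stand-ins for the embedding output followed by each layer's output
layer_outputs = [torch.randn(B, L, H) for _ in range(num_layers + 1)]

# "as list" layout: num_layers + 1 entries, each of shape B x L x H
mems_as_list = list(layer_outputs)

# stacked layout: one tensor of shape (num_layers + 1) x B x L x H
mems_as_tensor = layer_outputs[0].unsqueeze(0)
for states in layer_outputs[1:]:
    mems_as_tensor = torch.cat((mems_as_tensor, states.unsqueeze(0)), dim=0)

assert mems_as_tensor.shape == (num_layers + 1, B, L, H)
# cached_mems_list[-1] picks the last layer's states in either layout
assert torch.equal(mems_as_tensor[-1], mems_as_list[-1])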
Example 3
    def forward(self, hidden, hidden_mask=None, return_ortho_loss=False):
        """
        Project hidden [B x N x H] to fixed-size [B x k x H]

        return_ortho_loss - if True, also returns a loss term that encourages
                              orthogonal attention vectors
        """

        attention_scores = self.W2(
            self.act(self.W1(hidden) / self.attn_scale) /
            self.attn_scale).transpose(-1, -2)

        attention_mask = form_attention_mask(hidden_mask)
        if attention_mask is not None:
            attention_mask.squeeze_(1)
            attention_scores = attention_scores + attention_mask.to(
                attention_scores.dtype)

        A = torch.softmax(attention_scores, dim=-1)
        M = A @ hidden

        if return_ortho_loss:
            ortho_loss = ((A @ A.transpose(-1, -2)) -
                          torch.eye(self.k).type_as(A)).pow(2).sum()

            return M, ortho_loss
        else:
            return M
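
The orthogonality term is the squared Frobenius norm of A A^T - I: it is zero only when the k attention rows of A are orthonormal, i.e. one-hot distributions over distinct positions, so minimizing it pushes the k pooled vectors to attend to different parts of the input. A standalone illustration of the same computation with arbitrary B, k and N:

import torch

B, k, N = 2, 4, 10
scores = torch.randn(B, k, N)
A = torch.softmax(scores, dim=-1)          # B x k x N attention distributions

gram = A @ A.transpose(-1, -2)             # B x k x k
ortho_loss = (gram - torch.eye(k).type_as(A)).pow(2).sum()
print(ortho_loss)  # zero only when each batch's k rows are one-hot on distinct positions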
Example 4
    def apply_transformer(self, x, padding_mask=None):
        encoder_attn_mask = form_attention_mask(padding_mask)
        if self.layer_drop and self.training:
            # Stochastic layer drop as in Huang et al., https://arxiv.org/pdf/1603.09382.pdf:
            # during training each layer is skipped with probability self.layer_drop.
            for layer in self.layers:
                p = random.random()
                if p > self.layer_drop:
                    x = layer(x, encoder_attn_mask, x)
        else:
            for layer in self.layers:
                x = layer(x, encoder_attn_mask, x)
        return x
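
Since p is drawn uniformly from [0, 1), the test p > self.layer_drop keeps a layer with probability 1 - layer_drop, so over many forward passes roughly that fraction of layers runs. The snippet below only checks this sampling behaviour; the layer_drop value, layer count and trial count are made up for the illustration.

import random

layer_drop = 0.1   # assumed probability of skipping a layer
num_layers = 12
trials = 10_000

kept = 0
for _ in range(trials):
    for _ in range(num_layers):
        if random.random() > layer_drop:   # same test as in apply_transformer
            kept += 1

print(kept / (trials * num_layers))  # close to 1 - layer_drop, i.e. about 0.9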