Example #1
class TransformerDecoderLayer(nn.Module):
    """Decoder layer block.

    In the original paper each operation (multi-head attention, encoder
    attention or FFN) is postprocessed with: `dropout -> add residual ->
    layernorm`. In the tensor2tensor code they suggest that learning is more
    robust when preprocessing each layer with layernorm and postprocessing with:
    `dropout -> add residual`. We default to the approach in the paper, but the
    tensor2tensor approach can be enabled by setting
    *args.decoder_normalize_before* to ``True``.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        no_encoder_attn (bool, optional): disable attention to encoder
            outputs when set to ``True`` (default: False).
    """
    def __init__(self, args, no_encoder_attn=False):
        super().__init__()
        self.embed_dim = args.decoder_embed_dim
        self.self_attn = MultiheadAttention(
            self.embed_dim,
            args.decoder_attention_heads,
            dropout=args.attention_dropout,
        )
        self.dropout = args.dropout
        self.relu_dropout = args.relu_dropout
        self.normalize_before = args.decoder_normalize_before

        self.self_attn_layer_norm = LayerNorm(self.embed_dim)

        if no_encoder_attn:
            self.encoder_attn = None
            self.encoder_attn_layer_norm = None
        else:
            self.encoder_attn = MultiheadAttention(
                self.embed_dim,
                args.decoder_attention_heads,
                dropout=args.attention_dropout,
            )
            self.encoder_attn_layer_norm = LayerNorm(self.embed_dim)

        self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim)
        self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)

        self.final_layer_norm = LayerNorm(self.embed_dim)
        self.need_attn = True

        self.onnx_trace = False

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    def forward(self,
                x,
                encoder_out,
                encoder_padding_mask,
                incremental_state,
                prev_self_attn_state=None,
                prev_attn_state=None,
                self_attn_mask=None,
                self_attn_padding_mask=None):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_padding_mask (ByteTensor): binary ByteTensor of shape
                `(batch, src_len)` where padding elements are indicated by ``1``.

        Returns:
            encoded output of shape `(seq_len, batch, embed_dim)`
        """
        residual = x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True)
        if prev_self_attn_state is not None:
            if incremental_state is None:
                incremental_state = {}
            prev_key, prev_value = prev_self_attn_state
            saved_state = {"prev_key": prev_key, "prev_value": prev_value}
            self.self_attn._set_input_buffer(incremental_state, saved_state)
        x, _ = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            need_weights=False,
            attn_mask=self_attn_mask,
        )
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True)

        attn = None
        if self.encoder_attn is not None:
            residual = x
            x = self.maybe_layer_norm(self.encoder_attn_layer_norm,
                                      x,
                                      before=True)
            if prev_attn_state is not None:
                if incremental_state is None:
                    incremental_state = {}
                prev_key, prev_value = prev_attn_state
                saved_state = {"prev_key": prev_key, "prev_value": prev_value}
                self.encoder_attn._set_input_buffer(incremental_state,
                                                    saved_state)
            x, attn = self.encoder_attn(
                query=x,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                need_weights=(not self.training and self.need_attn),
            )
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = residual + x
            x = self.maybe_layer_norm(self.encoder_attn_layer_norm,
                                      x,
                                      after=True)

        residual = x
        x = self.maybe_layer_norm(self.final_layer_norm, x, before=True)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, p=self.relu_dropout, training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)
        if self.onnx_trace:
            saved_state = self.self_attn._get_input_buffer(incremental_state)
            self_attn_state = saved_state["prev_key"], saved_state[
                "prev_value"]
            return x, attn, self_attn_state
        return x, attn

    def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
        assert before ^ after
        # Apply the norm either before the sublayer (pre-norm, when
        # args.decoder_normalize_before is set) or after it (post-norm, the
        # paper's default); the XOR selects exactly one of the two positions.
        if after ^ self.normalize_before:
            return layer_norm(x)
        else:
            return x

    def make_generation_fast_(self, need_attn=False, **kwargs):
        self.need_attn = need_attn
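
# --- Hedged usage sketch (not part of the original example) -----------------
# A minimal smoke test of the layer above. It assumes the fairseq-style
# MultiheadAttention / LayerNorm / Linear helpers imported by the surrounding
# module are available with the signatures used in __init__; hyperparameter
# values are arbitrary, tensors follow the (seq_len, batch, embed_dim) layout
# from the forward() docstring, and no causal mask is applied for this check.
import argparse
import torch

args = argparse.Namespace(
    decoder_embed_dim=512,
    decoder_attention_heads=8,
    attention_dropout=0.1,
    dropout=0.1,
    relu_dropout=0.0,
    decoder_normalize_before=False,
    decoder_ffn_embed_dim=2048,
)
layer = TransformerDecoderLayer(args)

tgt_len, src_len, bsz = 7, 11, 2
x = torch.randn(tgt_len, bsz, args.decoder_embed_dim)
encoder_out = torch.randn(src_len, bsz, args.decoder_embed_dim)
out, attn = layer(x, encoder_out, encoder_padding_mask=None, incremental_state=None)
assert out.shape == x.shape
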
class TransformerBiModalDecoderLayer(nn.Module):
    """Bi-Modal Decoder layer block.

    In the original paper each operation (multi-head attention, encoder
    attention or FFN) is postprocessed with: `dropout -> add residual ->
    layernorm`. In the tensor2tensor code they suggest that learning is more
    robust when preprocessing each layer with layernorm and postprocessing with:
    `dropout -> add residual`. We default to the approach in the paper, but the
    tensor2tensor approach can be enabled by setting
    *args.decoder_normalize_before* to ``True``.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        no_encoder_attn (bool, optional): disable attention to encoder
            outputs when set to ``True`` (default: False).
    """
    def __init__(self,
                 args,
                 no_encoder_attn=False,
                 add_bias_kv=False,
                 add_zero_attn=False):
        super().__init__()
        self.embed_dim = args.decoder_embed_dim
        self.cross_self_attention = getattr(args, 'cross_self_attention',
                                            False)
        self.self_attn = MultiheadAttention(
            embed_dim=self.embed_dim,
            num_heads=args.decoder_attention_heads,
            dropout=args.attention_dropout,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
            self_attention=not self.cross_self_attention,
        )
        self.dropout = args.dropout
        self.activation_fn = utils.get_activation_fn(
            activation=getattr(args, 'activation_fn', 'relu'))
        self.activation_dropout = getattr(args, 'activation_dropout', 0)
        if self.activation_dropout == 0:
            # for backwards compatibility with models that use args.relu_dropout
            self.activation_dropout = getattr(args, 'relu_dropout', 0)
        self.normalize_before = args.decoder_normalize_before

        # use layerNorm rather than FusedLayerNorm for exporting.
        # char_inputs can be used to determine this.
        # TODO  remove this once we update apex with the fix
        export = getattr(args, 'char_inputs', False)
        self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export)

        if no_encoder_attn:
            self.audio_encoder_attn = None
            self.video_encoder_attn = None
            self.encoder_attn_layer_norm = None
        else:
            self.audio_encoder_attn = MultiheadAttention(
                self.embed_dim,
                args.decoder_attention_heads,
                kdim=getattr(args, 'encoder_embed_dim', None),
                vdim=getattr(args, 'encoder_embed_dim', None),
                dropout=args.attention_dropout,
                encoder_decoder_attention=True,
            )
            self.video_encoder_attn = MultiheadAttention(
                self.embed_dim,
                args.decoder_attention_heads,
                kdim=getattr(args, 'encoder_embed_dim', None),
                vdim=getattr(args, 'encoder_embed_dim', None),
                dropout=args.attention_dropout,
                encoder_decoder_attention=True,
            )
            self.encoder_attn_layer_norm = LayerNorm(self.embed_dim,
                                                     export=export)

        self.fc_av = Linear(self.embed_dim * 2, self.embed_dim)
        self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim)
        self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)

        self.final_layer_norm = LayerNorm(self.embed_dim, export=export)
        self.need_attn = True

        self.onnx_trace = False

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    def forward(
        self,
        x,
        audio_encoder_out=None,
        video_encoder_out=None,
        audio_encoder_padding_mask=None,
        video_encoder_padding_mask=None,
        incremental_state=None,
        prev_self_attn_state=None,
        prev_attn_state=None,
        self_attn_mask=None,
        self_attn_padding_mask=None,
        need_attn=True,
        need_head_weights=False,
    ):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_padding_mask (ByteTensor, optional): binary
                ByteTensor of shape `(batch, src_len)` where padding
                elements are indicated by ``1``.
            need_attn (bool, optional): return attention weights
            need_head_weights (bool, optional): return attention weights
                for each head (default: return average over heads).

        Returns:
            encoded output of shape `(seq_len, batch, embed_dim)`
        """
        if need_head_weights:
            need_attn = True

        residual = x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True)
        if prev_self_attn_state is not None:
            if incremental_state is None:
                incremental_state = {}
            prev_key, prev_value = prev_self_attn_state[:2]
            saved_state = {"prev_key": prev_key, "prev_value": prev_value}
            if len(prev_self_attn_state) >= 3:
                saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
            self.self_attn._set_input_buffer(incremental_state, saved_state)


#        if self.cross_self_attention and not (incremental_state is not None and "prev_key" in self.self_attn._get_input_buffer(incremental_state)):
#            if self_attn_mask is not None:
#                self_attn_mask = torch.cat(
#                        (x.new
#                            (x.size(0), audio_encoder_out.size(0) + video_encoder_out.size(0)).zero_(), self_attn_mask
#                            ),
#                        dim=1
#                        )
#            if self_attn_padding_mask is not None:
#                if audio_encoder_padding_mask is None and video_encoder_padding_mask is None:
#                    encoder_padding_mask = self_attn_padding_mask.new(
#                            audio_encoder_out.size(1), audio_encoder_out.size(0) + video_encoder_out.size(0)).zero_()
#                self_attn_padding_mask = torch.cat(
#                        (encoder_padding_mask, self_attn_padding_mask),
#                        dim=1
#                        )
#            y = torch.cat((audio_encoder_out, video_encoder_out, x), dim=0)
#        else:
#            y = x

#        x, attn = self.self_attn(
#            query=x,
#            key=y,
#            value=y,
#            key_padding_mask=self_attn_padding_mask,
#            incremental_state=incremental_state,
#            need_weights=False,
#            attn_mask=self_attn_mask,
#        )

        x, attn = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            need_weights=False,
            attn_mask=self_attn_mask,
        )

        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True)

        if self.audio_encoder_attn is not None and self.video_encoder_attn is not None:
            residual = x
            audio_x = self.maybe_layer_norm(self.encoder_attn_layer_norm,
                                            x,
                                            before=True)
            video_x = audio_x
            if prev_attn_state is not None:
                if incremental_state is None:
                    incremental_state = {}
                prev_key, prev_value = prev_attn_state[:2]
                saved_state = {"prev_key": prev_key, "prev_value": prev_value}
                if len(prev_attn_state) >= 3:
                    saved_state["prev_key_padding_mask"] = prev_attn_state[2]
                self.audio_encoder_attn._set_input_buffer(
                    incremental_state, saved_state)
                self.video_encoder_attn._set_input_buffer(
                    incremental_state, saved_state)

            audio_x, audio_attn = self.audio_encoder_attn(
                query=audio_x,
                key=audio_encoder_out,
                value=audio_encoder_out,
                key_padding_mask=None,
                incremental_state=incremental_state,
                static_kv=True,
                need_weights=need_attn
                or (not self.training and self.need_attn),
                need_head_weights=need_head_weights,
            )
            video_x, video_attn = self.video_encoder_attn(
                query=video_x,
                key=video_encoder_out,
                value=video_encoder_out,
                key_padding_mask=None,
                incremental_state=incremental_state,
                static_kv=True,
                need_weights=need_attn
                or (not self.training and self.need_attn),
                need_head_weights=need_head_weights,
            )
            x = torch.cat((audio_x, video_x), dim=2)
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = self.fc_av(x)
            x = residual + x
            x = self.maybe_layer_norm(self.encoder_attn_layer_norm,
                                      x,
                                      after=True)

        residual = x
        x = self.maybe_layer_norm(self.final_layer_norm, x, before=True)
        x = self.activation_fn(self.fc1(x))
        x = F.dropout(x, p=self.activation_dropout, training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)

        #################################################
        #        import os
        #        import scipy.io as sio
        #        path = "./img/"
        #        filename = "dcm_audio_dec_attn"
        #        i = 0
        #        while os.path.exists(path + filename + "_" + str(i) + ".mat"):
        #            i += 1
        #        filename = filename + '_' + str(i)
        #        sio.savemat(path + filename + ".mat", {filename:audio_attn[0,:,:].cpu().detach().numpy()})
        #        filename = "dcm_video_dec_attn"
        #        i = 0
        #        while os.path.exists(path + filename + "_" + str(i) + ".mat"):
        #            i += 1
        #        filename = filename + '_' + str(i)
        #        sio.savemat(path + filename + ".mat", {filename:video_attn[0,:,:].cpu().detach().numpy()})
        ##################################################

        if self.onnx_trace and incremental_state is not None:
            saved_state = self.self_attn._get_input_buffer(incremental_state)
            if self_attn_padding_mask is not None:
                self_attn_state = saved_state["prev_key"], saved_state[
                    "prev_value"], saved_state["prev_key_padding_mask"]
            else:
                self_attn_state = saved_state["prev_key"], saved_state[
                    "prev_value"]
            return x, attn, self_attn_state

        return x, attn

    def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
        assert before ^ after
        if after ^ self.normalize_before:
            return layer_norm(x)
        else:
            return x

    def make_generation_fast_(self, need_attn=False, **kwargs):
        self.need_attn = need_attn
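
# --- Hedged sketch of the audio/video fusion step (not in the original) -----
# TransformerBiModalDecoderLayer above fuses the two cross-attention contexts
# by concatenating them along the feature dimension and projecting back to
# embed_dim with fc_av. This standalone snippet shows just that pattern; the
# dimensions are arbitrary.
import torch
import torch.nn as nn

embed_dim = 8
fc_av = nn.Linear(embed_dim * 2, embed_dim)
audio_x = torch.randn(5, 2, embed_dim)   # (seq_len, batch, embed_dim)
video_x = torch.randn(5, 2, embed_dim)
fused = fc_av(torch.cat((audio_x, video_x), dim=2))
assert fused.shape == (5, 2, embed_dim)
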
class TransformerBiModalityAttentionDecoderLayer(nn.Module):
    """Bi-Modal Decoder layer block.

    In the original paper each operation (multi-head attention, encoder
    attention or FFN) is postprocessed with: `dropout -> add residual ->
    layernorm`. In the tensor2tensor code they suggest that learning is more
    robust when preprocessing each layer with layernorm and postprocessing with:
    `dropout -> add residual`. We default to the approach in the paper, but the
    tensor2tensor approach can be enabled by setting
    *args.decoder_normalize_before* to ``True``.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        no_encoder_attn (bool, optional): disable attention to encoder
            outputs when set to ``True`` (default: False).
    """
    def __init__(self,
                 args,
                 no_encoder_attn=False,
                 add_bias_kv=False,
                 add_zero_attn=False):
        super().__init__()
        self.embed_dim = args.decoder_embed_dim
        self.cross_self_attention = getattr(args, 'cross_self_attention',
                                            False)
        self.self_attn = MultiheadAttention(
            embed_dim=self.embed_dim,
            num_heads=args.decoder_attention_heads,
            dropout=args.attention_dropout,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
            self_attention=not self.cross_self_attention,
        )
        self.dropout = args.dropout
        self.activation_fn = utils.get_activation_fn(
            activation=getattr(args, 'activation_fn', 'relu'))
        self.activation_dropout = getattr(args, 'activation_dropout', 0)
        if self.activation_dropout == 0:
            # for backwards compatibility with models that use args.relu_dropout
            self.activation_dropout = getattr(args, 'relu_dropout', 0)
        self.normalize_before = args.decoder_normalize_before

        # use layerNorm rather than FusedLayerNorm for exporting.
        # char_inputs can be used to determine this.
        # TODO  remove this once we update apex with the fix
        export = getattr(args, 'char_inputs', False)
        self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export)

        if no_encoder_attn:
            self.audio_encoder_attn = None
            self.video_encoder_attn = None
            self.encoder_attn_layer_norm = None
        else:
            self.audio_encoder_attn = MultiheadAttention(
                self.embed_dim,
                args.decoder_attention_heads,
                kdim=getattr(args, 'encoder_embed_dim', None),
                vdim=getattr(args, 'encoder_embed_dim', None),
                dropout=args.attention_dropout,
                encoder_decoder_attention=True,
            )
            self.video_encoder_attn = MultiheadAttention(
                self.embed_dim,
                args.decoder_attention_heads,
                kdim=getattr(args, 'encoder_embed_dim', None),
                vdim=getattr(args, 'encoder_embed_dim', None),
                dropout=args.attention_dropout,
                encoder_decoder_attention=True,
            )
            self.encoder_attn_layer_norm = LayerNorm(self.embed_dim,
                                                     export=export)

        self.MA_1 = Linear(self.embed_dim, 1)
        self.MA_1_sig = nn.Sigmoid()
        self.MA_2 = Linear(self.embed_dim, 1)
        self.MA_2_sig = nn.Sigmoid()
        self.MA_softmax = nn.Softmax(dim=2)

        self.fc_av = Linear(self.embed_dim * 2, self.embed_dim)
        self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim)
        self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)

        self.final_layer_norm = LayerNorm(self.embed_dim, export=export)
        self.need_attn = True

        self.onnx_trace = False

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    def forward(
        self,
        x,
        audio_encoder_out=None,
        video_encoder_out=None,
        audio_encoder_padding_mask=None,
        video_encoder_padding_mask=None,
        incremental_state=None,
        prev_self_attn_state=None,
        prev_attn_state=None,
        self_attn_mask=None,
        self_attn_padding_mask=None,
        need_attn=False,
        need_head_weights=False,
    ):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_padding_mask (ByteTensor, optional): binary
                ByteTensor of shape `(batch, src_len)` where padding
                elements are indicated by ``1``.
            need_attn (bool, optional): return attention weights
            need_head_weights (bool, optional): return attention weights
                for each head (default: return average over heads).

        Returns:
            encoded output of shape `(seq_len, batch, embed_dim)`
        """
        if need_head_weights:
            need_attn = True

        residual = x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True)
        if prev_self_attn_state is not None:
            if incremental_state is None:
                incremental_state = {}
            prev_key, prev_value = prev_self_attn_state[:2]
            saved_state = {"prev_key": prev_key, "prev_value": prev_value}
            if len(prev_self_attn_state) >= 3:
                saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
            self.self_attn._set_input_buffer(incremental_state, saved_state)

        x, attn = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            need_weights=False,
            attn_mask=self_attn_mask,
        )

        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True)

        if self.audio_encoder_attn is not None and self.video_encoder_attn is not None:
            residual = x
            audio_x = self.maybe_layer_norm(self.encoder_attn_layer_norm,
                                            x,
                                            before=True)
            video_x = audio_x
            if prev_attn_state is not None:
                if incremental_state is None:
                    incremental_state = {}
                prev_key, prev_value = prev_attn_state[:2]
                saved_state = {"prev_key": prev_key, "prev_value": prev_value}
                if len(prev_attn_state) >= 3:
                    saved_state["prev_key_padding_mask"] = prev_attn_state[2]
                self.audio_encoder_attn._set_input_buffer(
                    incremental_state, saved_state)
                self.video_encoder_attn._set_input_buffer(
                    incremental_state, saved_state)

            audio_x, audio_attn = self.audio_encoder_attn(
                query=audio_x,
                key=audio_encoder_out,
                value=audio_encoder_out,
                key_padding_mask=None,
                incremental_state=incremental_state,
                static_kv=True,
                need_weights=need_attn
                or (not self.training and self.need_attn),
                need_head_weights=need_head_weights,
            )
            video_x, video_attn = self.video_encoder_attn(
                query=video_x,
                key=video_encoder_out,
                value=video_encoder_out,
                key_padding_mask=None,
                incremental_state=incremental_state,
                static_kv=True,
                need_weights=need_attn
                or (not self.training and self.need_attn),
                need_head_weights=need_head_weights,
            )
            # modality attention based on encoder-decoder attention #
            # [T,B,F] -> sigmoid([T,B,1]) -> [T,B,2]
            audio_coeff = self.MA_1_sig(self.MA_1(audio_x))
            video_coeff = self.MA_2_sig(self.MA_2(video_x))
            modality_coeff = self.MA_softmax(
                torch.cat((audio_coeff, video_coeff),
                          2)  # idx 0 - audio, idx 1 - video
            )
            audio_x = audio_x * modality_coeff[:, :, 0:1] * 2
            video_x = video_x * modality_coeff[:, :, 1:2] * 2
            x = torch.cat((audio_x, video_x), dim=2)
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = self.fc_av(x)
            x = residual + x
            x = self.maybe_layer_norm(self.encoder_attn_layer_norm,
                                      x,
                                      after=True)

        residual = x
        x = self.maybe_layer_norm(self.final_layer_norm, x, before=True)
        x = self.activation_fn(self.fc1(x))
        x = F.dropout(x, p=self.activation_dropout, training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)
        if self.onnx_trace and incremental_state is not None:
            saved_state = self.self_attn._get_input_buffer(incremental_state)
            if self_attn_padding_mask is not None:
                self_attn_state = saved_state["prev_key"], saved_state[
                    "prev_value"], saved_state["prev_key_padding_mask"]
            else:
                self_attn_state = saved_state["prev_key"], saved_state[
                    "prev_value"]
            return x, attn, self_attn_state
        return x, attn

    def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
        assert before ^ after
        if after ^ self.normalize_before:
            return layer_norm(x)
        else:
            return x

    def make_generation_fast_(self, need_attn=False, **kwargs):
        self.need_attn = need_attn
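
# --- Hedged sketch of the modality-attention gating (not in the original) ---
# The MA_1/MA_2 + softmax logic above, in isolation: each modality gets a
# scalar gate per (time, batch) position, the two gates are normalized over
# the modality axis, and each context is rescaled. The factor of 2 makes the
# gating an identity when both modalities receive equal weight (0.5 each).
import torch
import torch.nn as nn

embed_dim = 8
MA_1, MA_2 = nn.Linear(embed_dim, 1), nn.Linear(embed_dim, 1)
audio_x = torch.randn(5, 2, embed_dim)   # (seq_len, batch, embed_dim)
video_x = torch.randn(5, 2, embed_dim)

audio_coeff = torch.sigmoid(MA_1(audio_x))                    # (T, B, 1)
video_coeff = torch.sigmoid(MA_2(video_x))                    # (T, B, 1)
modality_coeff = torch.softmax(
    torch.cat((audio_coeff, video_coeff), dim=2), dim=2)      # (T, B, 2)

audio_x = audio_x * modality_coeff[:, :, 0:1] * 2
video_x = video_x * modality_coeff[:, :, 1:2] * 2
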
Example #4
class TransformerDecoderLayer(nn.Module):
    """Decoder layer block.

    In the original paper each operation (multi-head attention, encoder
    attention or FFN) is postprocessed with: `dropout -> add residual ->
    layernorm`. In the tensor2tensor code they suggest that learning is more
    robust when preprocessing each layer with layernorm and postprocessing with:
    `dropout -> add residual`. We default to the approach in the paper, but the
    tensor2tensor approach can be enabled by setting
    *args.decoder_normalize_before* to ``True``.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        no_encoder_attn (bool, optional): disable attention to encoder
            outputs when set to ``True`` (default: False).
    """

    def __init__(self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False, LayerNum=None):
        super().__init__()

        global tmp_file

        self.args = args
        if not hasattr(self.args, 'mixed_precision'):
            self.args.mixed_precision = False
        if not hasattr(self.args, 'plot_variance'):
            self.args.plot_variance = False
        if not hasattr(self.args, 'plot_gradient'):
            self.args.plot_gradient = False

        self.normalize_before = args.decoder_normalize_before
        self.embed_dim = args.decoder_embed_dim
        self.cross_self_attention = getattr(args, 'cross_self_attention', False)

        self.layer_num = LayerNum
        if 'adaptive' in args.init_type:
            assert not self.normalize_before

            self.self_attn = MultiheadAttention(
                embed_dim=self.embed_dim,
                num_heads=args.decoder_attention_heads,
                dropout=args.attention_dropout,
                add_bias_kv=add_bias_kv,
                add_zero_attn=add_zero_attn,
                self_attention=not self.cross_self_attention
            )

            assert not no_encoder_attn
            self.encoder_attn = MultiheadAttention(
                self.embed_dim,
                args.decoder_attention_heads,
                kdim=getattr(args, 'encoder_embed_dim', None),
                vdim=getattr(args, 'encoder_embed_dim', None),
                dropout=args.attention_dropout,
                encoder_decoder_attention=True
            )

            self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim)
            self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)
            
            if 'adaptive-profiling' == args.init_type:
                if not tmp_file:
                    tmp_file = open('profile.ratio.init', 'w')
                self.self_ratio_change = nn.Parameter(torch.ones(self.embed_dim))
                self.encoder_ratio_change = nn.Parameter(torch.ones(self.embed_dim))
                self.fc_ratio_change = nn.Parameter(torch.ones(self.embed_dim))
            else:
                if not tmp_file:
                    tmp_file = open('profile.ratio.init', 'r')

                layer_iter, next_value = [float(tup) for tup in tmp_file.readline().split()]
                print('layer_num: {}, layer_iter: {}'.format(self.layer_num, layer_iter))
                assert layer_iter == 3 * self.layer_num + 1
                print('decoder self ratio: {}'.format(next_value))
                self.self_ratio_change = nn.Parameter(torch.ones(self.embed_dim))
                self.self_ratio_change.data.fill_(next_value)

                layer_iter, next_value = [float(tup) for tup in tmp_file.readline().split()]
                print('layer_num: {}, layer_iter: {}'.format(self.layer_num, layer_iter))
                assert layer_iter == 3 * self.layer_num + 2
                print('decoder en ratio: {}'.format(next_value))
                self.encoder_ratio_change = nn.Parameter(torch.ones(self.embed_dim))
                self.encoder_ratio_change.data.fill_(next_value)

                layer_iter, next_value = [float(tup) for tup in tmp_file.readline().split()]
                print('layer_num: {}, layer_iter: {}'.format(self.layer_num, layer_iter))
                assert layer_iter == 3 * self.layer_num + 3
                print('decoder ffn ratio: {}'.format(next_value))
                self.fc_ratio_change = nn.Parameter(torch.ones(self.embed_dim))
                self.fc_ratio_change.data.fill_(next_value)

            export = getattr(args, 'char_inputs', False)
            self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export) 
            self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=export) 
            self.final_layer_norm = LayerNorm(self.embed_dim, export=export) 
        else:
            self.self_attn = MultiheadAttention(
                embed_dim=self.embed_dim,
                num_heads=args.decoder_attention_heads,
                dropout=args.attention_dropout,
                add_bias_kv=add_bias_kv,
                add_zero_attn=add_zero_attn,
                self_attention=not self.cross_self_attention
            )

            assert not no_encoder_attn
            self.encoder_attn = MultiheadAttention(
                self.embed_dim,
                args.decoder_attention_heads,
                kdim=getattr(args, 'encoder_embed_dim', None),
                vdim=getattr(args, 'encoder_embed_dim', None),
                dropout=args.attention_dropout,
                encoder_decoder_attention=True
            )
            
            self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim)
            self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)
            if args.init_type == 'looklinear':
                self.fc1.weight.data[int(args.decoder_ffn_embed_dim / 2):, :] = -self.fc1.weight.data[0: int(args.decoder_ffn_embed_dim / 2), :]
                self.fc2.weight.data[:, int(args.decoder_ffn_embed_dim / 2):] = -self.fc2.weight.data[:, 0: int(args.decoder_ffn_embed_dim / 2)]
            else:
                assert args.init_type == 'default'

            export = getattr(args, 'char_inputs', False)
            self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export)
            if no_encoder_attn:
                self.encoder_attn = None
                self.encoder_attn_layer_norm = None
            else:
                self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=export)
            self.final_layer_norm = LayerNorm(self.embed_dim, export=export)

        self.dropout = args.dropout
        self.activation_fn = utils.get_activation_fn(
            activation=getattr(args, 'activation_fn', 'relu')
        )
        self.activation_dropout = getattr(args, 'activation_dropout', 0)
        if self.activation_dropout == 0:
            self.activation_dropout = getattr(args, 'relu_dropout', 0)


        self.need_attn = True

        self.onnx_trace = False

        if args.fp16:
            self.in_type=torch.half
        else:
            self.in_type=torch.float

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    def forward(
        self,
        x,
        encoder_out=None,
        encoder_padding_mask=None,
        incremental_state=None,
        prev_self_attn_state=None,
        prev_attn_state=None,
        self_attn_mask=None,
        self_attn_padding_mask=None,
        need_attn=False,
        need_head_weights=False,
    ):
        not_initialized = ('adaptive-profiling' == self.args.init_type) and (1.0 == self.self_ratio_change.min()) and self.training

        if need_head_weights:
            need_attn = True

        residual = x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True)

        if self.args.mixed_precision: 
            x = x.type(self.in_type)
        if prev_self_attn_state is not None:
            if incremental_state is None:
                incremental_state = {}
            prev_key, prev_value = prev_self_attn_state[:2]
            saved_state = {"prev_key": prev_key, "prev_value": prev_value}
            if len(prev_self_attn_state) >= 3:
                saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
            self.self_attn._set_input_buffer(incremental_state, saved_state)
        if self.cross_self_attention and not (incremental_state is not None and "prev_key" in self.self_attn._get_input_buffer(incremental_state)):
            if self_attn_mask is not None:
                self_attn_mask = torch.cat((x.new(x.size(0), encoder_out.size(0)).zero_(), self_attn_mask), dim=1)
            if self_attn_padding_mask is not None:
                if encoder_padding_mask is None:
                    encoder_padding_mask = self_attn_padding_mask.new(encoder_out.size(1), encoder_out.size(0)).zero_()
                self_attn_padding_mask = torch.cat((encoder_padding_mask, self_attn_padding_mask), dim=1)
            y = torch.cat((encoder_out, x), dim=0)
        else:
            y = x
        x, attn = self.self_attn(
            query=x,
            key=y,
            value=y,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            need_weights=False,
            attn_mask=self_attn_mask,
        )
        x = F.dropout(x, p=self.dropout, training=self.training)

        if self.args.mixed_precision:
            x = x.float()
        if 'adaptive' in self.args.init_type:
            if not_initialized:
                global decoder_ratio, tmp_file
                tmp_layer_ind = self.layer_num * 3 + 1
                tmp_weight = tmp_layer_ind ** self.args.adaptive_scale
                tmp_ratio = decoder_ratio / tmp_weight
                tmp_file.write('{} {}\n'.format(tmp_layer_ind, tmp_ratio))
                self.self_ratio_change.data.fill_(tmp_ratio)
                print ('decoder self attn ratio: {}'.format(tmp_ratio))
                input_var = np.var((residual * self.self_ratio_change).clone().cpu().float().data.view(-1).numpy())
                output_var = np.var(x.clone().cpu().float().data.view(-1).numpy())
                decoder_ratio = np.sqrt(input_var + output_var) * tmp_weight
            x0 = x + residual * self.self_ratio_change
        else:
            x0 = residual + x
        x0 = self.maybe_layer_norm(self.self_attn_layer_norm, x0, after=True)
        if self.args.plot_gradient:
            x0.register_hook(lambda grad: print('{} decoder s-att: {}'.format(self.layer_num, grad.norm().item())))
        x = x0
        if self.encoder_attn is not None:
            residual = x
            x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x0, before=True)

            if self.args.mixed_precision: 
                x = x.type(self.in_type)
            if prev_attn_state is not None:
                if incremental_state is None:
                    incremental_state = {}
                prev_key, prev_value = prev_attn_state[:2]
                saved_state = {"prev_key": prev_key, "prev_value": prev_value}
                if len(prev_attn_state) >= 3:
                    saved_state["prev_key_padding_mask"] = prev_attn_state[2]
                self.encoder_attn._set_input_buffer(incremental_state, saved_state)
            x, attn = self.encoder_attn(
                query=x,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                need_weights=need_attn or (not self.training and self.need_attn),
                need_head_weights=need_head_weights,
            )
            x = F.dropout(x, p=self.dropout, training=self.training)

            if self.args.mixed_precision:
                x = x.float()
            if 'adaptive' in self.args.init_type:
                if not_initialized:
                    tmp_layer_ind = self.layer_num * 3 + 2
                    tmp_weight = tmp_layer_ind ** self.args.adaptive_scale
                    tmp_ratio = decoder_ratio / tmp_weight
                    tmp_file.write('{} {}\n'.format(tmp_layer_ind, tmp_ratio))
                    self.encoder_ratio_change.data.fill_(tmp_ratio)
                    print ('decoder encoder attn ratio: {}'.format(tmp_ratio))
                    input_var = np.var((residual * self.encoder_ratio_change).clone().cpu().float().data.view(-1).numpy())
                    output_var = np.var(x.clone().cpu().float().data.view(-1).numpy())
                    decoder_ratio = np.sqrt(input_var + output_var) * tmp_weight
                x1 = x + residual * self.encoder_ratio_change
            else:
                x1 = residual + x
            x1 = self.maybe_layer_norm(self.encoder_attn_layer_norm, x1, after=True)
            if self.args.plot_gradient:
                x1.register_hook(lambda grad: print('{} decoder e-att: {}'.format(self.layer_num, grad.norm().item())))
            x = x1
        residual = x
        x = self.maybe_layer_norm(self.final_layer_norm, x, before=True)

        if self.args.mixed_precision: 
            x = x.type(self.in_type)
        bx = self.fc1(x)
        hx = self.activation_fn(bx)
        x = F.dropout(hx, p=self.activation_dropout, training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)

        if self.args.mixed_precision:
            x = x.float()
        if 'adaptive' in self.args.init_type:
            if not_initialized:
                tmp_layer_ind = self.layer_num * 3 + 3
                tmp_weight = tmp_layer_ind ** self.args.adaptive_scale
                tmp_ratio = decoder_ratio / tmp_weight
                tmp_file.write('{} {}\n'.format(tmp_layer_ind, tmp_ratio))
                self.fc_ratio_change.data.fill_(tmp_ratio)
                print ('decoder ffn ratio: {}'.format(tmp_ratio))
                input_var = np.var( (residual * self.fc_ratio_change) .clone().cpu().float().data.view(-1).numpy())
                output_var = np.var(x.clone().cpu().float().data.view(-1).numpy())
                decoder_ratio = np.sqrt(input_var + output_var) * tmp_weight
            x2 = x + residual * self.fc_ratio_change
        else:
            x2 = residual + x
        x2 = self.maybe_layer_norm(self.final_layer_norm, x2, after=True)
        if self.args.plot_gradient:
            x2.register_hook(lambda grad: print('{} decoder ffn: {}'.format(self.layer_num, grad.norm().item())))

        if self.onnx_trace and incremental_state is not None:
            saved_state = self.self_attn._get_input_buffer(incremental_state)
            if self_attn_padding_mask is not None:
                self_attn_state = saved_state["prev_key"], saved_state["prev_value"], saved_state["prev_key_padding_mask"]
            else:
                self_attn_state = saved_state["prev_key"], saved_state["prev_value"]
            return x2, attn, self_attn_state

        return x2, attn

    def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
        assert before ^ after
        if after ^ self.normalize_before:
            return layer_norm(x)
        else:
            return x

    def make_generation_fast_(self, need_attn=False, **kwargs):
        self.need_attn = need_attn
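
# --- Hedged sketch of the profile.ratio.init round-trip (not in the original)
# The 'adaptive' init path above writes one "index value" line per residual
# branch (three per decoder layer, indexed 3 * layer_num + {1, 2, 3}) during
# an 'adaptive-profiling' run, and reads them back in the same order on the
# real run. The values below are made up; only the line format and the
# indexing mirror the code.
ratios = {1: 1.0, 2: 0.9, 3: 0.8}   # layer 0: self-attn, encoder-attn, ffn

with open('profile.ratio.init', 'w') as f:      # adaptive-profiling pass
    for idx, val in ratios.items():
        f.write('{} {}\n'.format(idx, val))

with open('profile.ratio.init', 'r') as f:      # adaptive training pass
    for layer_num in [0]:
        for branch in (1, 2, 3):
            layer_iter, next_value = [float(tok) for tok in f.readline().split()]
            assert layer_iter == 3 * layer_num + branch
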
Example #5
class TransformerDecoderLayerBN(nn.Module):
    """Decoder layer block.

    In the original paper each operation (multi-head attention, encoder
    attention or FFN) is postprocessed with: `dropout -> add residual ->
    layernorm`. In the tensor2tensor code they suggest that learning is more
    robust when preprocessing each layer with layernorm and postprocessing with:
    `dropout -> add residual`. We default to the approach in the paper, but the
    tensor2tensor approach can be enabled by setting
    *args.decoder_normalize_before* to ``True``.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        no_encoder_attn (bool, optional): disable attention to encoder
            outputs when set to ``True`` (default: False).
    """
    def __init__(self,
                 args,
                 no_encoder_attn=False,
                 add_bias_kv=False,
                 add_zero_attn=False):
        super().__init__()
        self.embed_dim = args.decoder_embed_dim
        self.cross_self_attention = getattr(args, 'cross_self_attention',
                                            False)
        self.self_attn = MultiheadAttention(
            embed_dim=self.embed_dim,
            num_heads=args.decoder_attention_heads,
            dropout=args.attention_dropout,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
            self_attention=not self.cross_self_attention,
        )
        self.dropout = args.dropout
        self.activation_fn = utils.get_activation_fn(
            activation=getattr(args, 'activation_fn', 'relu'))
        self.activation_dropout = getattr(args, 'activation_dropout', 0)
        if self.activation_dropout == 0:
            # for backwards compatibility with models that use args.relu_dropout
            self.activation_dropout = getattr(args, 'relu_dropout', 0)
        self.normalize_before = args.decoder_normalize_before

        # use layerNorm rather than FusedLayerNorm for exporting.
        # char_inputs can be used to determine this.
        # TODO  remove this once we update apex with the fix
        export = getattr(args, 'char_inputs', False)
        self.self_attn_batch_norm = AdaptiveBN(255)

        if no_encoder_attn:
            self.encoder_attn = None
            self.encoder_attn_batch_norm = None
        else:
            self.encoder_attn = MultiheadAttention(
                self.embed_dim,
                args.decoder_attention_heads,
                kdim=getattr(args, 'encoder_embed_dim', None),
                vdim=getattr(args, 'encoder_embed_dim', None),
                dropout=args.attention_dropout,
                encoder_decoder_attention=True,
            )
            self.encoder_attn_batch_norm = AdaptiveBN(255)

        self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim)
        self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)

        self.final_batch_norm = AdaptiveBN(255)
        self.need_attn = True

        self.onnx_trace = False

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    def forward(
        self,
        x,
        encoder_out=None,
        encoder_padding_mask=None,
        incremental_state=None,
        prev_self_attn_state=None,
        prev_attn_state=None,
        self_attn_mask=None,
        self_attn_padding_mask=None,
        need_attn=False,
        need_head_weights=False,
    ):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_padding_mask (ByteTensor, optional): binary
                ByteTensor of shape `(batch, src_len)` where padding
                elements are indicated by ``1``.
            need_attn (bool, optional): return attention weights
            need_head_weights (bool, optional): return attention weights
                for each head (default: return average over heads).

        Returns:
            encoded output of shape `(seq_len, batch, embed_dim)`
        """
        if need_head_weights:
            need_attn = True

        residual = x
        x = self.maybe_batch_norm(self.self_attn_batch_norm, x, before=True)
        if prev_self_attn_state is not None:
            if incremental_state is None:
                incremental_state = {}
            prev_key, prev_value = prev_self_attn_state[:2]
            saved_state = {"prev_key": prev_key, "prev_value": prev_value}
            if len(prev_self_attn_state) >= 3:
                saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
            self.self_attn._set_input_buffer(incremental_state, saved_state)

        if self.cross_self_attention and not (
                incremental_state is not None and "prev_key"
                in self.self_attn._get_input_buffer(incremental_state)):
            if self_attn_mask is not None:
                self_attn_mask = torch.cat((x.new(
                    x.size(0), encoder_out.size(0)).zero_(), self_attn_mask),
                                           dim=1)
            if self_attn_padding_mask is not None:
                if encoder_padding_mask is None:
                    encoder_padding_mask = self_attn_padding_mask.new(
                        encoder_out.size(1), encoder_out.size(0)).zero_()
                self_attn_padding_mask = torch.cat(
                    (encoder_padding_mask, self_attn_padding_mask), dim=1)
            y = torch.cat((encoder_out, x), dim=0)
        else:
            y = x

        x, attn = self.self_attn(
            query=x,
            key=y,
            value=y,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            need_weights=False,
            attn_mask=self_attn_mask,
        )
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_batch_norm(self.self_attn_batch_norm, x, after=True)

        if self.encoder_attn is not None:
            residual = x
            x = self.maybe_batch_norm(self.encoder_attn_batch_norm,
                                      x,
                                      before=True)
            if prev_attn_state is not None:
                if incremental_state is None:
                    incremental_state = {}
                prev_key, prev_value = prev_attn_state[:2]
                saved_state = {"prev_key": prev_key, "prev_value": prev_value}
                if len(prev_attn_state) >= 3:
                    saved_state["prev_key_padding_mask"] = prev_attn_state[2]
                self.encoder_attn._set_input_buffer(incremental_state,
                                                    saved_state)

            x, attn = self.encoder_attn(
                query=x,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                need_weights=need_attn
                or (not self.training and self.need_attn),
                need_head_weights=need_head_weights,
            )
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = residual + x
            x = self.maybe_batch_norm(self.encoder_attn_batch_norm,
                                      x,
                                      after=True)

        residual = x
        x = self.maybe_batch_norm(self.final_batch_norm, x, before=True)
        x = self.activation_fn(self.fc1(x))
        x = F.dropout(x, p=self.activation_dropout, training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_batch_norm(self.final_batch_norm, x, after=True)
        if self.onnx_trace and incremental_state is not None:
            saved_state = self.self_attn._get_input_buffer(incremental_state)
            if self_attn_padding_mask is not None:
                self_attn_state = saved_state["prev_key"], saved_state[
                    "prev_value"], saved_state["prev_key_padding_mask"]
            else:
                self_attn_state = saved_state["prev_key"], saved_state[
                    "prev_value"]
            return x, attn, self_attn_state
        return x, attn

    def maybe_batch_norm(self, batch_norm, x, before=False, after=False):
        assert before ^ after
        if after ^ self.normalize_before:
            seq_len, bsz, embed_dim = x.size()
            x = x.transpose(0, 1)
            x = batch_norm(x)
            x = x.transpose(0, 1)
            return x
        else:
            return x

    def make_generation_fast_(self, need_attn=False, **kwargs):
        self.need_attn = need_attn
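
# --- Hypothetical stand-in for AdaptiveBN (not in the original) -------------
# AdaptiveBN is not defined in this excerpt; maybe_batch_norm above only
# relies on it accepting and returning a batch-first (batch, seq_len,
# embed_dim) tensor. The constructor argument used above (255) is not
# explained here, so this sketch simply assumes a feature count and wraps
# nn.BatchNorm1d over the feature channels; it is NOT the real module.
import torch
import torch.nn as nn

class BatchNormOverFeatures(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        self.bn = nn.BatchNorm1d(embed_dim)

    def forward(self, x):
        # (batch, seq_len, embed_dim) -> (batch, embed_dim, seq_len) for
        # BatchNorm1d, then back.
        return self.bn(x.transpose(1, 2)).transpose(1, 2)

x = torch.randn(7, 2, 16)                  # (seq_len, batch, embed_dim)
bn = BatchNormOverFeatures(16)
y = bn(x.transpose(0, 1)).transpose(0, 1)  # same transpose dance as above
assert y.shape == x.shape
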
Example #6
class TransformerDecoderLayerPhase2(nn.Module):
    """Second phase of decoder layer block
    This layer will take the input from the ecoder and phirst pass decoder.
    papers.nips.cc/paper/6775-deliberation-networks-sequence-generation-beyond-one-pass-decoding.pdf
    """
    def __init__(
        self,
        args,
        no_encoder_decoder_attn=False,
        add_bias_kv=False,
        add_zero_attn=False,
    ):
        super().__init__()
        self.embed_dim = args.decoder_embed_dim
        self.self_attn = MultiheadAttention(
            embed_dim=self.embed_dim,
            num_heads=args.decoder_attention_heads,
            dropout=args.attention_dropout,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
            self_attention=True,
        )
        self.dropout = args.dropout
        self.activation_fn = utils.get_activation_fn(
            activation=getattr(args, "activation_fn", "relu"))
        self.activation_dropout = getattr(args, "activation_dropout", 0)
        if self.activation_dropout == 0:
            # for backwards compatibility with models that use args.relu_dropout
            self.activation_dropout = getattr(args, "relu_dropout", 0)
        self.normalize_before = args.decoder_normalize_before

        # use layerNorm rather than FusedLayerNorm for exporting.
        # char_inputs can be used to determine this.
        # TODO  remove this once we update apex with the fix
        export = getattr(args, "char_inputs", False)
        self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export)

        if no_encoder_decoder_attn:
            self.encoder_attn = None
            self.decoder_attn = None
            self.encoder_attn_layer_norm = None
            self.decoder_attn_layer_norm = None
        else:
            self.encoder_attn = MultiheadAttention(
                self.embed_dim,
                args.decoder_attention_heads,
                dropout=args.attention_dropout,
                encoder_decoder_attention=True,
            )
            self.decoder_attn = MultiheadAttention(
                self.embed_dim,
                args.decoder_attention_heads,
                dropout=args.attention_dropout,
                encoder_decoder_attention=True,
            )
            self.encoder_attn_layer_norm = LayerNorm(self.embed_dim,
                                                     export=export)
            self.decoder_attn_layer_norm = LayerNorm(self.embed_dim,
                                                     export=export)

        self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim)
        self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)

        self.final_layer_norm = LayerNorm(self.embed_dim, export=export)
        self.need_attn = True

        self.onnx_trace = False

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    def forward(
        self,
        x,
        encoder_out=None,
        encoder_padding_mask=None,
        decoder_out=None,
        incremental_state=None,
        prev_self_attn_state=None,
        prev_encoder_attn_state=None,
        prev_decoder_attn_state=None,
        self_attn_mask=None,
        self_attn_padding_mask=None,
    ):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_padding_mask (ByteTensor): binary ByteTensor of shape
                `(batch, src_len)` where padding elements are indicated by ``1``.

        Returns:
            encoded output of shape `(seq_len, batch, embed_dim)`
        """
        residual = x
        x_self_attention = self.maybe_layer_norm(self.self_attn_layer_norm,
                                                 x,
                                                 before=True)
        if prev_self_attn_state is not None:
            if incremental_state is None:
                incremental_state = {}
            prev_key, prev_value = prev_self_attn_state
            saved_state = {"prev_key": prev_key, "prev_value": prev_value}
            self.self_attn._set_input_buffer(incremental_state, saved_state)
        x_self_attention, attn = self.self_attn(
            query=x_self_attention,
            key=x_self_attention,
            value=x_self_attention,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            need_weights=False,
            attn_mask=self_attn_mask,
        )
        x_self_attention = F.dropout(x_self_attention,
                                     p=self.dropout,
                                     training=self.training)
        x_self_attention = residual + x_self_attention
        x_self_attention = self.maybe_layer_norm(self.self_attn_layer_norm,
                                                 x_self_attention,
                                                 after=True)

        if self.encoder_attn is not None:
            residual = x
            x_encoder_attention = self.maybe_layer_norm(
                self.encoder_attn_layer_norm, x, before=True)
            if prev_encoder_attn_state is not None:
                if incremental_state is None:
                    incremental_state = {}
                prev_key, prev_value = prev_encoder_attn_state
                saved_state = {"prev_key": prev_key, "prev_value": prev_value}
                self.encoder_attn._set_input_buffer(incremental_state,
                                                    saved_state)
            x_encoder_attention, attn = self.encoder_attn(
                query=x_encoder_attention,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                need_weights=(not self.training and self.need_attn),
            )
            x_encoder_attention = F.dropout(x_encoder_attention,
                                            p=self.dropout,
                                            training=self.training)
            x_encoder_attention = residual + x_encoder_attention
            x_encoder_attention = self.maybe_layer_norm(
                self.encoder_attn_layer_norm, x_encoder_attention, after=True)

        if self.decoder_attn is not None:
            residual = x
            x_decoder_attention = self.maybe_layer_norm(
                self.decoder_attn_layer_norm, x, before=True)
            if prev_decoder_attn_state is not None:
                if incremental_state is None:
                    incremental_state = {}
                prev_key, prev_value = prev_decoder_attn_state
                saved_state = {"prev_key": prev_key, "prev_value": prev_value}
                self.decoder_attn._set_input_buffer(incremental_state,
                                                    saved_state)
            x_decoder_attention, attn = self.decoder_attn(
                query=x_decoder_attention,
                key=decoder_out,
                value=decoder_out,
                incremental_state=incremental_state,
                static_kv=True,
                need_weights=(not self.training and self.need_attn),
            )
            x_decoder_attention = F.dropout(x_decoder_attention,
                                            p=self.dropout,
                                            training=self.training)
            x_decoder_attention = residual + x_decoder_attention
            x_decoder_attention = self.maybe_layer_norm(
                self.decoder_attn_layer_norm, x_decoder_attention, after=True)
        x = x_self_attention + x_encoder_attention + x_decoder_attention

        residual = x
        x = self.maybe_layer_norm(self.final_layer_norm, x, before=True)
        x = self.activation_fn(self.fc1(x))
        x = F.dropout(x, p=self.activation_dropout, training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)
        if self.onnx_trace and incremental_state is not None:
            saved_state = self.self_attn._get_input_buffer(incremental_state)
            self_attn_state = saved_state["prev_key"], saved_state[
                "prev_value"]
            return x, attn, self_attn_state
        return x, attn

    def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
        assert before ^ after
        if after ^ self.normalize_before:
            return layer_norm(x)
        else:
            return x

    def make_generation_fast_(self, need_attn=False, **kwargs):
        self.need_attn = need_attn
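A quick aside on the `maybe_layer_norm` helper above: the `after ^ self.normalize_before` test is what switches the block between the paper's post-norm order and the tensor2tensor pre-norm order. A minimal, self-contained sanity check of that convention (the helper name below is illustrative, not part of the listing):

def norm_fires(normalize_before: bool, after: bool) -> bool:
    # mirrors `if after ^ self.normalize_before: return layer_norm(x)`
    return after ^ normalize_before

# post-norm (normalize_before=False): the norm runs on the `after=True` call
assert norm_fires(normalize_before=False, after=True)
assert not norm_fires(normalize_before=False, after=False)
# pre-norm (normalize_before=True): the norm runs on the `before=True` call
assert norm_fires(normalize_before=True, after=False)
assert not norm_fires(normalize_before=True, after=True)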
Example No. 7
class TransformerDecoderLayer(nn.Module):
    """Decoder layer block.

    In the original paper each operation (multi-head attention, encoder
    attention or FFN) is postprocessed with: `dropout -> add residual ->
    layernorm`. In the tensor2tensor code they suggest that learning is more
    robust when preprocessing each layer with layernorm and postprocessing with:
    `dropout -> add residual`. We default to the approach in the paper, but the
    tensor2tensor approach can be enabled by setting
    *args.decoder_normalize_before* to ``True``.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).
    """
    def __init__(self,
                 args,
                 no_encoder_attn=False,
                 add_bias_kv=False,
                 add_zero_attn=False):
        super().__init__()
        self.embed_dim = args.decoder_embed_dim
        embed_dim = self.embed_dim
        self.cross_self_attention = getattr(args, "cross_self_attention",
                                            False)
        self.self_attn = MultiheadAttention(
            embed_dim=self.embed_dim,
            num_heads=args.decoder_attention_heads,
            dropout=args.attention_dropout,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
            self_attention=not self.cross_self_attention,
        )
        #        self.dropout = [0.05, 0.1, 0.25, 0.3]
        #        self.dropout = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.3]
        # Per-layer dropout schedule; forward() indexes this list with `idx`.
        self.dropout = [
            0, 0, 0, 0.1, 0.1, 0.1, 0.1, 0.1, 0.2, 0.2, 0.2, 0.3, 0.3
        ]
        self.activation_fn = utils.get_activation_fn(
            activation=getattr(args, "activation_fn", "relu"))
        self.activation_dropout = getattr(args, "activation_dropout", 0)
        if self.activation_dropout == 0:
            # for backwards compatibility with models that use args.relu_dropout
            self.activation_dropout = getattr(args, "relu_dropout", 0)
        self.normalize_before = args.decoder_normalize_before

        # use layerNorm rather than FusedLayerNorm for exporting.
        # char_inputs can be used to determine this.
        # TODO  remove this once we update apex with the fix
        export = getattr(args, "char_inputs", False)
        #        self.self_attn_layer_norm = SlimmableLayernorm([int(self.embed_dim / 4), int(self.embed_dim * 2 / 4), int(self.embed_dim * 3 / 4), self.embed_dim])
        self.self_attn_layer_norm = SlimmableLayernorm(
            [int(embed_dim * i / 16) for i in range(4, 16)] + [embed_dim])

        if no_encoder_attn:
            self.encoder_attn = None
            self.encoder_attn_layer_norm = None
        else:
            self.encoder_attn = MultiheadAttention(
                self.embed_dim,
                args.decoder_attention_heads,
                kdim=getattr(args, "encoder_embed_dim", None),
                vdim=getattr(args, "encoder_embed_dim", None),
                dropout=args.attention_dropout,
                encoder_decoder_attention=True,
            )
            self.encoder_attn_layer_norm = SlimmableLayernorm(
                [int(embed_dim * i / 16) for i in range(4, 16)] + [embed_dim])

        self.fc1 = SLinear(
            [int(embed_dim * i / 16) for i in range(4, 16)] + [embed_dim],
            [int(args.encoder_ffn_embed_dim * i / 16) for i in range(4, 16)]
            + [args.encoder_ffn_embed_dim])

        self.fc2 = SLinear(
            [int(args.encoder_ffn_embed_dim * i / 16) for i in range(4, 16)]
            + [args.encoder_ffn_embed_dim],
            [int(embed_dim * i / 16) for i in range(4, 16)] + [embed_dim])

        self.final_layer_norm = SlimmableLayernorm(
            [int(embed_dim * i / 16) for i in range(4, 16)] + [embed_dim])

        self.need_attn = True

        self.onnx_trace = False

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    def forward(
        self,
        x,
        idx,
        encoder_out: Optional[torch.Tensor] = None,
        encoder_padding_mask: Optional[torch.Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str,
                                                   Optional[Tensor]]]] = None,
        prev_self_attn_state: Optional[List[torch.Tensor]] = None,
        prev_attn_state: Optional[List[torch.Tensor]] = None,
        self_attn_mask: Optional[torch.Tensor] = None,
        self_attn_padding_mask: Optional[torch.Tensor] = None,
        need_attn: bool = False,
        need_head_weights: bool = False,
    ):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            idx (int): layer index used to select the per-layer dropout rate
                and the active slimmable width.
            encoder_padding_mask (ByteTensor, optional): binary
                ByteTensor of shape `(batch, src_len)` where padding
                elements are indicated by ``1``.
            need_attn (bool, optional): return attention weights
            need_head_weights (bool, optional): return attention weights
                for each head (default: return average over heads).

        Returns:
            encoded output of shape `(seq_len, batch, embed_dim)`
        """
        if need_head_weights:
            need_attn = True

        residual = x
        if self.normalize_before:
            x = self.self_attn_layer_norm(x, idx)
        if prev_self_attn_state is not None:
            prev_key, prev_value = prev_self_attn_state[:2]
            saved_state: Dict[str, Optional[Tensor]] = {
                "prev_key": prev_key,
                "prev_value": prev_value,
            }
            if len(prev_self_attn_state) >= 3:
                saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
            assert incremental_state is not None
            self.self_attn._set_input_buffer(incremental_state, saved_state)
        _self_attn_input_buffer = self.self_attn._get_input_buffer(
            incremental_state)
        if self.cross_self_attention and not (
                incremental_state is not None and _self_attn_input_buffer
                is not None and "prev_key" in _self_attn_input_buffer):
            if self_attn_mask is not None:
                assert encoder_out is not None
                self_attn_mask = torch.cat((x.new_zeros(
                    x.size(0), encoder_out.size(0)), self_attn_mask),
                                           dim=1)
            if self_attn_padding_mask is not None:
                if encoder_padding_mask is None:
                    assert encoder_out is not None
                    encoder_padding_mask = self_attn_padding_mask.new_zeros(
                        encoder_out.size(1), encoder_out.size(0))
                self_attn_padding_mask = torch.cat(
                    (encoder_padding_mask, self_attn_padding_mask), dim=1)
            assert encoder_out is not None
            y = torch.cat((encoder_out, x), dim=0)
        else:
            y = x

        x, attn = self.self_attn(
            index=idx,
            query=x,
            key=y,
            value=y,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            need_weights=False,
            attn_mask=self_attn_mask,
        )
        x = F.dropout(x, p=self.dropout[idx], training=self.training)
        x = residual + x
        if not self.normalize_before:
            x = self.self_attn_layer_norm(x, idx)

        if self.encoder_attn is not None:
            residual = x
            if self.normalize_before:
                x = self.encoder_attn_layer_norm(x, idx)
            if prev_attn_state is not None:
                prev_key, prev_value = prev_attn_state[:2]
                saved_state: Dict[str, Optional[Tensor]] = {
                    "prev_key": prev_key,
                    "prev_value": prev_value,
                }
                if len(prev_attn_state) >= 3:
                    saved_state["prev_key_padding_mask"] = prev_attn_state[2]
                assert incremental_state is not None
                self.encoder_attn._set_input_buffer(incremental_state,
                                                    saved_state)

            x, attn = self.encoder_attn(
                index=idx,
                query=x,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                need_weights=need_attn
                or (not self.training and self.need_attn),
                need_head_weights=need_head_weights,
            )
            x = F.dropout(x, p=self.dropout[idx], training=self.training)
            x = residual + x
            if not self.normalize_before:
                x = self.encoder_attn_layer_norm(x, idx)

        residual = x
        if self.normalize_before:
            x = self.final_layer_norm(x, idx)
        x = self.activation_fn(self.fc1(x, idx))
        x = F.dropout(x,
                      p=float(self.activation_dropout),
                      training=self.training)
        x = self.fc2(x, idx)
        x = F.dropout(x, p=self.dropout[idx], training=self.training)
        x = residual + x
        if not self.normalize_before:
            x = self.final_layer_norm(x, idx)
        if self.onnx_trace and incremental_state is not None:
            saved_state = self.self_attn._get_input_buffer(incremental_state)
            assert saved_state is not None
            if self_attn_padding_mask is not None:
                self_attn_state = [
                    saved_state["prev_key"],
                    saved_state["prev_value"],
                    saved_state["prev_key_padding_mask"],
                ]
            else:
                self_attn_state = [
                    saved_state["prev_key"], saved_state["prev_value"]
                ]
            return x, attn, self_attn_state
        return x, attn, None

    def make_generation_fast_(self, need_attn: bool = False, **kwargs):
        self.need_attn = need_attn
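The `SlimmableLayernorm` and `SLinear` modules used above are not included in this listing (and the `MultiheadAttention` here also takes an `index` argument, so it is presumably a width-indexed variant as well). A minimal sketch of what such modules might look like, assuming each simply selects one of several pre-declared widths by `index`; the parameterization and slicing scheme below are assumptions, not the original implementation:

import torch
import torch.nn as nn
import torch.nn.functional as F

class SlimmableLayernorm(nn.Module):
    # Hypothetical sketch: keeps one LayerNorm per candidate width and applies
    # the one selected by `index` to the leading `width` features.
    def __init__(self, widths):
        super().__init__()
        self.widths = list(widths)
        self.norms = nn.ModuleList(nn.LayerNorm(w) for w in self.widths)

    def forward(self, x, index):
        w = self.widths[index]
        return self.norms[index](x[..., :w])

class SLinear(nn.Module):
    # Hypothetical sketch: one weight matrix sized for the largest
    # configuration, sliced down to (in_widths[index], out_widths[index]).
    def __init__(self, in_widths, out_widths):
        super().__init__()
        self.in_widths, self.out_widths = list(in_widths), list(out_widths)
        self.weight = nn.Parameter(torch.empty(self.out_widths[-1], self.in_widths[-1]))
        self.bias = nn.Parameter(torch.zeros(self.out_widths[-1]))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, x, index):
        w_in, w_out = self.in_widths[index], self.out_widths[index]
        return F.linear(x[..., :w_in], self.weight[:w_out, :w_in], self.bias[:w_out])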
Example No. 8
class TarcTransformerDecoderLayer(nn.Module):
    """Decoder layer block.

    In the original paper each operation (multi-head attention, encoder
    attention or FFN) is postprocessed with: `dropout -> add residual ->
    layernorm`. In the tensor2tensor code they suggest that learning is more
    robust when preprocessing each layer with layernorm and postprocessing with:
    `dropout -> add residual`. We default to the approach in the paper, but the
    tensor2tensor approach can be enabled by setting
    *args.decoder_normalize_before* to ``True``.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).
    """

    def __init__(
        self, args, no_encoder_attn=False, num_cross_attentions=0, add_bias_kv=False, add_zero_attn=False
    ):
        super().__init__()
        self.num_cross_attentions = num_cross_attentions
        self.embed_dim = args.decoder_embed_dim
        self.cross_self_attention = getattr(args, "cross_self_attention", False)
        self.self_attn = MultiheadAttention(
            embed_dim=self.embed_dim,
            num_heads=args.decoder_attention_heads,
            dropout=args.attention_dropout,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
            self_attention=not self.cross_self_attention,
        )
        self.dropout = args.dropout
        self.activation_fn = utils.get_activation_fn(
            activation=getattr(args, "activation_fn", "relu")
        )
        self.activation_dropout = getattr(args, "activation_dropout", 0)
        if self.activation_dropout == 0:
            # for backwards compatibility with models that use args.relu_dropout
            self.activation_dropout = getattr(args, "relu_dropout", 0)
        self.normalize_before = args.decoder_normalize_before

        # use layerNorm rather than FusedLayerNorm for exporting.
        # char_inputs can be used to determine this.
        # TODO  remove this once we update apex with the fix
        export = getattr(args, "char_inputs", False)
        self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export)

        if no_encoder_attn:
            self.encoder_attn = None
            self.encoder_attn_layer_norm = None
        else:
            self.encoder_attn = MultiheadAttention(
                self.embed_dim,
                args.decoder_attention_heads,
                kdim=getattr(args, "encoder_embed_dim", None),
                vdim=getattr(args, "encoder_embed_dim", None),
                dropout=args.attention_dropout,
                encoder_decoder_attention=True,
            )
            self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=export)

        # This is my main modification: cross-attentions that attend to the other decoders' outputs
        self.cross_attentions = nn.ModuleList()
        self.cross_attentions_norm = nn.ModuleList()
        for i in range(num_cross_attentions):
            self.cross_attentions.append(
                MultiheadAttention(
                    self.embed_dim,
                    args.decoder_attention_heads,
                    kdim=self.embed_dim,
                    vdim=self.embed_dim,
                    dropout=args.attention_dropout,
                    encoder_decoder_attention=True,
                )
            )
            self.cross_attentions_norm.append(
                LayerNorm(self.embed_dim, export=export)
            )

        self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim)
        self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)

        self.final_layer_norm = LayerNorm(self.embed_dim, export=export)
        self.need_attn = True

        self.onnx_trace = False

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    def forward(
        self,
        x,
        encoder_out: Optional[List[torch.Tensor]] = None,
        encoder_padding_mask: Optional[torch.Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        prev_self_attn_state: Optional[List[torch.Tensor]] = None,
        prev_attn_state: Optional[List[torch.Tensor]] = None,
        prev_cross_attn_state: Optional[List[List[torch.Tensor]]] = None,
        self_attn_mask: Optional[torch.Tensor] = None,
        self_attn_padding_mask: Optional[torch.Tensor] = None,
        need_attn: bool = False,
        need_head_weights: bool = False,
    ):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_padding_mask (ByteTensor, optional): binary
                ByteTensor of shape `(batch, src_len)` where padding
                elements are indicated by ``1``.
            need_attn (bool, optional): return attention weights
            need_head_weights (bool, optional): return attention weights
                for each head (default: return average over heads).

        Returns:
            encoded output of shape `(seq_len, batch, embed_dim)`
        """
        if need_head_weights:
            need_attn = True

        assert len(self.cross_attentions)+1 == len(encoder_out)

        residual = x
        if self.normalize_before:
            x = self.self_attn_layer_norm(x)
        if prev_self_attn_state is not None:
            prev_key, prev_value = prev_self_attn_state[:2]
            saved_state: Dict[str, Optional[Tensor]] = {
                "prev_key": prev_key,
                "prev_value": prev_value,
            }
            if len(prev_self_attn_state) >= 3:
                saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
            assert incremental_state is not None
            self.self_attn._set_input_buffer(incremental_state, saved_state)
        _self_attn_input_buffer = self.self_attn._get_input_buffer(incremental_state)
        if self.cross_self_attention and not (
            incremental_state is not None
            and _self_attn_input_buffer is not None
            and "prev_key" in _self_attn_input_buffer
        ):
            if self_attn_mask is not None:
                assert encoder_out[0] is not None
                self_attn_mask = torch.cat(
                    (x.new_zeros(x.size(0), encoder_out[0].size(0)), self_attn_mask), dim=1
                )
            if self_attn_padding_mask is not None:
                if encoder_padding_mask is None:
                    assert encoder_out[0] is not None
                    encoder_padding_mask = self_attn_padding_mask.new_zeros(
                        encoder_out[0].size(1), encoder_out[0].size(0)
                    )
                self_attn_padding_mask = torch.cat(
                    (encoder_padding_mask, self_attn_padding_mask), dim=1
                )
            assert encoder_out[0] is not None
            y = torch.cat((encoder_out[0], x), dim=0)
        else:
            y = x

        x, attn = self.self_attn(
            query=x,
            key=y,
            value=y,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            need_weights=False,
            attn_mask=self_attn_mask,
        )
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        if not self.normalize_before:
            x = self.self_attn_layer_norm(x)

        cross_attn_x = x
        if self.encoder_attn is not None:
            residual = x
            if self.normalize_before:
                x = self.encoder_attn_layer_norm(x)
            if prev_attn_state is not None:
                prev_key, prev_value = prev_attn_state[:2]
                saved_state: Dict[str, Optional[Tensor]] = {
                    "prev_key": prev_key,
                    "prev_value": prev_value,
                }
                if len(prev_attn_state) >= 3:
                    saved_state["prev_key_padding_mask"] = prev_attn_state[2]
                assert incremental_state is not None
                self.encoder_attn._set_input_buffer(incremental_state, saved_state)

            x, attn = self.encoder_attn(
                query=x,
                key=encoder_out[0],
                value=encoder_out[0],
                key_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                need_weights=need_attn or (not self.training and self.need_attn),
                need_head_weights=need_head_weights,
            )
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = residual + x
            if not self.normalize_before:
                x = self.encoder_attn_layer_norm(x)

        if self.num_cross_attentions > 0:
            residual = cross_attn_x
            all_att_output = torch.zeros_like(cross_attn_x)
            if self.normalize_before:
                cross_attn_x = self.cross_attentions_norm[0](cross_attn_x)
            for i in range(len(self.cross_attentions)):
                if prev_cross_attn_state is not None:
                    prev_key, prev_value = prev_cross_attn_state[i][:2]
                    cross_saved_state: Dict[str, Optional[Tensor]] = {
                        "prev_key": prev_key,
                        "prev_value": prev_value,
                    }
                    if len(prev_cross_attn_state[i]) >= 3:
                        cross_saved_state["prev_key_padding_mask"] = prev_cross_attn_state[i][2]
                    assert incremental_state is not None
                    self.cross_attentions[i]._set_input_buffer(incremental_state, cross_saved_state)

                att_output, attn = self.cross_attentions[i](
                    query=cross_attn_x,
                    key=encoder_out[i+1],
                    value=encoder_out[i+1],
                    key_padding_mask=None,
                    incremental_state=incremental_state,
                    static_kv=True,
                    need_weights=need_attn or (not self.training and self.need_attn),
                    need_head_weights=need_head_weights,
                )
                att_output = F.dropout(att_output, p=self.dropout, training=self.training) 
                all_att_output = att_output + all_att_output
            if self.encoder_attn is not None:
                x = x + all_att_output  # encoder_attn and cross_attentions use the same residual, so no need to add it twice
            else:
                # without encoder attention, x still equals the residual here,
                # so add it only once
                x = residual + all_att_output
            if not self.normalize_before:
                x = self.cross_attentions_norm[0](x)

        residual = x
        if self.normalize_before:
            x = self.final_layer_norm(x)
        x = self.activation_fn(self.fc1(x))
        x = F.dropout(x, p=float(self.activation_dropout), training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        if not self.normalize_before:
            x = self.final_layer_norm(x)
        if self.onnx_trace and incremental_state is not None:
            saved_state = self.self_attn._get_input_buffer(incremental_state)
            assert saved_state is not None
            if self_attn_padding_mask is not None:
                self_attn_state = [
                    saved_state["prev_key"],
                    saved_state["prev_value"],
                    saved_state["prev_key_padding_mask"],
                ]
            else:
                self_attn_state = [saved_state["prev_key"], saved_state["prev_value"]]
            return x, attn, self_attn_state
        return x, attn, None

    def make_generation_fast_(self, need_attn: bool = False, **kwargs):
        self.need_attn = need_attn

    @torch.jit.export
    def reorder_incremental_state(
        self,
        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
        new_order: Tensor,
    ):
        """Scriptable reorder incremental state in transformer layers."""
        self.self_attn.reorder_incremental_state(incremental_state, new_order)

        if self.encoder_attn is not None:
            self.encoder_attn.reorder_incremental_state(incremental_state, new_order)

        if self.num_cross_attentions > 0:
            for attn in self.cross_attentions:
                attn.reorder_incremental_state(incremental_state, new_order)
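For orientation, a minimal smoke-test sketch of how this layer might be driven. The hyper-parameter values are assumptions, and the sketch assumes the same fairseq imports used by the surrounding code (`MultiheadAttention`, `LayerNorm`, `Linear`, `utils`). Note that `encoder_out` is a list: the encoder states first, then one tensor per cross-attention (see the assert at the top of `forward`).

import argparse
import torch

# Assumed configuration; adjust to your model.
args = argparse.Namespace(
    decoder_embed_dim=256, decoder_attention_heads=4, decoder_ffn_embed_dim=1024,
    encoder_embed_dim=256, attention_dropout=0.0, dropout=0.1,
    activation_fn="relu", activation_dropout=0.0, relu_dropout=0.0,
    decoder_normalize_before=False, char_inputs=False, cross_self_attention=False,
)
layer = TarcTransformerDecoderLayer(args, num_cross_attentions=2)

tgt = torch.randn(7, 2, 256)                                # (tgt_len, batch, embed_dim)
enc = torch.randn(11, 2, 256)                               # encoder states
others = [torch.randn(9, 2, 256), torch.randn(5, 2, 256)]   # the other decoders' outputs
out, attn, _ = layer(tgt, encoder_out=[enc] + others)       # len(encoder_out) == num_cross_attentions + 1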
Example No. 9
File: bgt.py Project: jwcmu/bgt
class TransformerSentenceEmbeddingDecoderLayer(TransformerDecoderLayer):
    """Decoder layer block.

    In the original paper each operation (multi-head attention, encoder
    attention or FFN) is postprocessed with: `dropout -> add residual ->
    layernorm`. In the tensor2tensor code they suggest that learning is more
    robust when preprocessing each layer with layernorm and postprocessing with:
    `dropout -> add residual`. We default to the approach in the paper, but the
    tensor2tensor approach can be enabled by setting
    *args.decoder_normalize_before* to ``True``.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).
    """
    def __init__(self,
                 args,
                 add_bias_kv=False,
                 add_zero_attn=False,
                 do_trans=True):
        super().__init__(args)
        self.embed_dim = args.decoder_embed_dim
        self.args = args
        self.do_trans = do_trans
        self.self_attn = MultiheadAttention(
            embed_dim=self.embed_dim,
            num_heads=args.decoder_attention_heads,
            dropout=args.attention_dropout,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
            self_attention=True)
        self.dropout = args.dropout
        self.activation_fn = utils.get_activation_fn(
            activation=getattr(args, 'activation_fn', 'relu'))
        self.activation_dropout = getattr(args, 'activation_dropout', 0)
        if self.activation_dropout == 0:
            # for backwards compatibility with models that use args.relu_dropout
            self.activation_dropout = getattr(args, 'relu_dropout', 0)
        self.normalize_before = args.decoder_normalize_before

        # use layerNorm rather than FusedLayerNorm for exporting.
        # char_inputs can be used to determine this.
        # TODO  remove this once we update apex with the fix
        export = getattr(args, 'char_inputs', False)
        self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export)

        self.encoder_attn_layer_norm = LayerNorm(self.embed_dim)

        self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim)
        self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)

        self.final_layer_norm = LayerNorm(self.embed_dim, export=export)
        self.need_attn = True

        self.onnx_trace = False

        if do_trans:
            self.decoder_fc1 = Linear(self.embed_dim + self.args.latent_size,
                                      self.embed_dim)
        else:
            self.decoder_fc1 = Linear(
                self.embed_dim + self.args.latent_size * 2, self.embed_dim)

    def forward(
        self,
        x,
        sent_emb=None,
        encoder_out=None,
        encoder_padding_mask=None,
        incremental_state=None,
        prev_self_attn_state=None,
        prev_attn_state=None,
        self_attn_mask=None,
        self_attn_padding_mask=None,
    ):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_padding_mask (ByteTensor): binary ByteTensor of shape
                `(batch, src_len)` where padding elements are indicated by ``1``.

        Returns:
            encoded output of shape `(seq_len, batch, embed_dim)`
        """
        residual = x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True)
        if prev_self_attn_state is not None:
            if incremental_state is None:
                incremental_state = {}
            prev_key, prev_value = prev_self_attn_state
            saved_state = {"prev_key": prev_key, "prev_value": prev_value}
            self.self_attn._set_input_buffer(incremental_state, saved_state)
        x, attn = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            need_weights=False,
            attn_mask=self_attn_mask,
        )
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True)

        residual = x
        x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, before=True)
        size = (x.size()[0], x.size()[1], sent_emb.size()[-1])

        concat_sent_emb = torch.cat((x, sent_emb.expand(size)), dim=2)
        x = self.decoder_fc1(concat_sent_emb)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, after=True)

        residual = x
        x = self.maybe_layer_norm(self.final_layer_norm, x, before=True)
        x = self.activation_fn(self.fc1(x))
        x = F.dropout(x, p=self.activation_dropout, training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)
        if self.onnx_trace and incremental_state is not None:
            saved_state = self.self_attn._get_input_buffer(incremental_state)
            self_attn_state = saved_state["prev_key"], saved_state[
                "prev_value"]
            return x, attn, self_attn_state
        return x, attn
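The distinguishing step in this block is the fusion of a sentence embedding in place of encoder attention. A small shape sketch of that concatenation, with arbitrary assumed sizes:

import torch

seq_len, batch, embed_dim, latent = 7, 2, 512, 64           # assumed sizes
x = torch.randn(seq_len, batch, embed_dim)                  # decoder states
sent_emb = torch.randn(1, 1, latent)                        # one latent vector per sentence

size = (x.size(0), x.size(1), sent_emb.size(-1))            # broadcast target
concat = torch.cat((x, sent_emb.expand(size)), dim=2)       # (seq_len, batch, embed_dim + latent)
assert concat.shape == (seq_len, batch, embed_dim + latent)
# decoder_fc1 then projects embed_dim + latent (or + 2*latent when do_trans=False)
# back down to embed_dim before the usual FFN block.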
Example No. 10
class TransformerDecoderLayer(nn.Module):
    """Decoder layer block.

    In the original paper each operation (multi-head attention, encoder
    attention or FFN) is postprocessed with: `dropout -> add residual ->
    layernorm`. In the tensor2tensor code they suggest that learning is more
    robust when preprocessing each layer with layernorm and postprocessing with:
    `dropout -> add residual`. We default to the approach in the paper, but the
    tensor2tensor approach can be enabled by setting
    *args.decoder_normalize_before* to ``True``.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).
    """
    def __init__(self,
                 args,
                 self_attn_pattern=None,
                 encoder_attn_pattern=None,
                 no_encoder_attn=False,
                 add_bias_kv=False,
                 add_zero_attn=False):
        super().__init__()
        if self_attn_pattern is not None and args.PRUNE_DEC_SELF_ATTN:
            cpu_np_pattern = self_attn_pattern.cpu().numpy()
            d1_bounds, d2_bounds = find_bounds(cpu_np_pattern)
            prune_random = False
            try:
                if args.RANDOM_PRUNE:  #random prune
                    prune_random = True
                    self.self_attn_mask = torch.from_numpy(
                        random_mask(cpu_np_pattern, args.TAU))
                    if args.CUDA:
                        self.self_attn_mask = self.self_attn_mask.cuda()
            except AttributeError:
                # args may not define RANDOM_PRUNE / CUDA
                pass
            if not prune_random:
                cpu_np_pattern = cpu_np_pattern[:, :, 0:d1_bounds, 0:d2_bounds]
                target_percentile = args.TAU * 100
                threshold = np.percentile(cpu_np_pattern,
                                          target_percentile,
                                          interpolation='nearest')
                self.self_attn_mask = (self_attn_pattern <= threshold)
        else:
            self.self_attn_mask = None

        if encoder_attn_pattern is not None and args.PRUNE_ENC_DEC_ATTN:
            cpu_np_pattern = encoder_attn_pattern.cpu().numpy()
            d1_bounds, d2_bounds = find_bounds(cpu_np_pattern)
            prune_random = False
            try:
                if args.RANDOM_PRUNE:  #random prune
                    prune_random = True
                    self.encoder_attn_mask = torch.from_numpy(
                        random_mask(cpu_np_pattern, args.TAU))
                    if args.CUDA:
                        self.encoder_attn_mask = self.encoder_attn_mask.cuda()
            except AttributeError:
                # args may not define RANDOM_PRUNE / CUDA
                pass
            if not prune_random:
                cpu_np_pattern = cpu_np_pattern[:, :, 0:d1_bounds, 0:d2_bounds]
                target_percentile = args.TAU * 100
                threshold = np.percentile(cpu_np_pattern,
                                          target_percentile,
                                          interpolation='nearest')
                self.encoder_attn_mask = (encoder_attn_pattern <= threshold)
        else:
            self.encoder_attn_mask = None

        self.embed_dim = args.decoder_embed_dim
        self.cross_self_attention = getattr(args, "cross_self_attention",
                                            False)
        self.self_attn = MultiheadAttention(
            embed_dim=self.embed_dim,
            num_heads=args.decoder_attention_heads,
            dropout=args.attention_dropout,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
            self_attention=not self.cross_self_attention,
            args=args,
        )
        self.dropout = args.dropout
        self.activation_fn = utils.get_activation_fn(
            activation=getattr(args, "activation_fn", "relu"))
        self.activation_dropout = getattr(args, "activation_dropout", 0)
        if self.activation_dropout == 0:
            # for backwards compatibility with models that use args.relu_dropout
            self.activation_dropout = getattr(args, "relu_dropout", 0)
        self.normalize_before = args.decoder_normalize_before

        # use layerNorm rather than FusedLayerNorm for exporting.
        # char_inputs can be used to determine this.
        # TODO  remove this once we update apex with the fix
        export = getattr(args, "char_inputs", False)
        self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export)

        if no_encoder_attn:
            self.encoder_attn = None
            self.encoder_attn_layer_norm = None
        else:
            self.encoder_attn = MultiheadAttention(
                self.embed_dim,
                args.decoder_attention_heads,
                kdim=getattr(args, "encoder_embed_dim", None),
                vdim=getattr(args, "encoder_embed_dim", None),
                dropout=args.attention_dropout,
                encoder_decoder_attention=True,
                args=args,
            )
            self.encoder_attn_layer_norm = LayerNorm(self.embed_dim,
                                                     export=export)

        self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim)
        self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)

        self.final_layer_norm = LayerNorm(self.embed_dim, export=export)
        self.need_attn = True

        self.onnx_trace = False

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    def forward(
        self,
        x,
        encoder_out: Optional[torch.Tensor] = None,
        encoder_padding_mask: Optional[torch.Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str,
                                                   Optional[Tensor]]]] = None,
        prev_self_attn_state: Optional[List[torch.Tensor]] = None,
        prev_attn_state: Optional[List[torch.Tensor]] = None,
        self_attn_mask: Optional[torch.Tensor] = None,
        self_attn_padding_mask: Optional[torch.Tensor] = None,
        need_attn: bool = False,
        need_head_weights: bool = False,
    ):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_padding_mask (ByteTensor, optional): binary
                ByteTensor of shape `(batch, src_len)` where padding
                elements are indicated by ``1``.
            need_attn (bool, optional): return attention weights
            need_head_weights (bool, optional): return attention weights
                for each head (default: return average over heads).

        Returns:
            encoded output of shape `(seq_len, batch, embed_dim)`
        """
        if need_head_weights:
            need_attn = True

        residual = x
        if self.normalize_before:
            x = self.self_attn_layer_norm(x)
        if prev_self_attn_state is not None:
            prev_key, prev_value = prev_self_attn_state[:2]
            saved_state: Dict[str, Optional[Tensor]] = {
                "prev_key": prev_key,
                "prev_value": prev_value,
            }
            if len(prev_self_attn_state) >= 3:
                saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
            assert incremental_state is not None
            self.self_attn._set_input_buffer(incremental_state, saved_state)
        _self_attn_input_buffer = self.self_attn._get_input_buffer(
            incremental_state)
        if self.cross_self_attention and not (
                incremental_state is not None and _self_attn_input_buffer
                is not None and "prev_key" in _self_attn_input_buffer):
            if self_attn_mask is not None:
                assert encoder_out is not None
                self_attn_mask = torch.cat((x.new_zeros(
                    x.size(0), encoder_out.size(0)), self_attn_mask),
                                           dim=1)
            if self_attn_padding_mask is not None:
                if encoder_padding_mask is None:
                    assert encoder_out is not None
                    encoder_padding_mask = self_attn_padding_mask.new_zeros(
                        encoder_out.size(1), encoder_out.size(0))
                self_attn_padding_mask = torch.cat(
                    (encoder_padding_mask, self_attn_padding_mask), dim=1)
            assert encoder_out is not None
            y = torch.cat((encoder_out, x), dim=0)
        else:
            y = x

        x, attn = self.self_attn(
            query=x,
            key=y,
            value=y,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            need_weights=False,
            attn_mask=self_attn_mask,
            prune_attn_mask=self.self_attn_mask,
        )
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        if not self.normalize_before:
            x = self.self_attn_layer_norm(x)

        if self.encoder_attn is not None:
            residual = x
            if self.normalize_before:
                x = self.encoder_attn_layer_norm(x)
            if prev_attn_state is not None:
                prev_key, prev_value = prev_attn_state[:2]
                saved_state: Dict[str, Optional[Tensor]] = {
                    "prev_key": prev_key,
                    "prev_value": prev_value,
                }
                if len(prev_attn_state) >= 3:
                    saved_state["prev_key_padding_mask"] = prev_attn_state[2]
                assert incremental_state is not None
                self.encoder_attn._set_input_buffer(incremental_state,
                                                    saved_state)

            x, attn = self.encoder_attn(query=x,
                                        key=encoder_out,
                                        value=encoder_out,
                                        key_padding_mask=encoder_padding_mask,
                                        incremental_state=incremental_state,
                                        static_kv=True,
                                        need_weights=need_attn or
                                        (not self.training and self.need_attn),
                                        need_head_weights=need_head_weights,
                                        prune_attn_mask=self.encoder_attn_mask)
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = residual + x
            if not self.normalize_before:
                x = self.encoder_attn_layer_norm(x)

        residual = x
        if self.normalize_before:
            x = self.final_layer_norm(x)
        x = self.activation_fn(self.fc1(x))
        x = F.dropout(x,
                      p=float(self.activation_dropout),
                      training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        if not self.normalize_before:
            x = self.final_layer_norm(x)
        if self.onnx_trace and incremental_state is not None:
            saved_state = self.self_attn._get_input_buffer(incremental_state)
            assert saved_state is not None
            if self_attn_padding_mask is not None:
                self_attn_state = [
                    saved_state["prev_key"],
                    saved_state["prev_value"],
                    saved_state["prev_key_padding_mask"],
                ]
            else:
                self_attn_state = [
                    saved_state["prev_key"], saved_state["prev_value"]
                ]
            return x, attn, self_attn_state
        return x, attn, None

    def make_generation_fast_(self, need_attn: bool = False, **kwargs):
        self.need_attn = need_attn

    @torch.jit.export
    def reorder_incremental_state(
        self,
        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
        new_order: Tensor,
    ):
        """Scriptable reorder incremental state in transformer layers."""
        self.self_attn.reorder_incremental_state(incremental_state, new_order)

        if self.encoder_attn is not None:
            self.encoder_attn.reorder_incremental_state(
                incremental_state, new_order)
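The `find_bounds` and `random_mask` helpers used in the constructor above are not part of this listing. Rough sketches of plausible implementations, given only how they are called; both bodies, and their exact semantics, are assumptions:

import numpy as np

def find_bounds(pattern):
    # Extent of the non-zero region along the last two axes of an attention
    # pattern shaped like (..., heads, tgt_len, src_len); used to crop padding
    # before computing the pruning percentile.
    nz = np.nonzero(pattern)
    return int(nz[-2].max()) + 1, int(nz[-1].max()) + 1

def random_mask(pattern, tau):
    # Mark a random fraction `tau` of attention positions as pruned (True).
    return np.random.rand(*pattern.shape) < tau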
class TransformerDecoderLayer(nn.Module):
    """Decoder layer block.

    In the original paper each operation (multi-head attention, encoder
    attention or FFN) is postprocessed with: `dropout -> add residual ->
    layernorm`. In the tensor2tensor code they suggest that learning is more
    robust when preprocessing each layer with layernorm and postprocessing with:
    `dropout -> add residual`. We default to the approach in the paper, but the
    tensor2tensor approach can be enabled by setting
    *args.decoder_normalize_before* to ``True``.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).
    """
    def __init__(self,
                 args,
                 no_encoder_attn=False,
                 add_bias_kv=False,
                 add_zero_attn=False):
        super().__init__()
        self.embed_dim = args.decoder_embed_dim
        self.self_attn = MultiheadAttention(
            embed_dim=self.embed_dim,
            num_heads=args.decoder_attention_heads,
            dropout=args.attention_dropout,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
        )
        self.dropout = args.dropout
        self.activation_fn = utils.get_activation_fn(
            activation=getattr(args, 'activation_fn', 'relu'))
        self.activation_dropout = getattr(args, 'activation_dropout', 0)
        if self.activation_dropout == 0:
            # for backwards compatibility with models that use args.relu_dropout
            self.activation_dropout = getattr(args, 'relu_dropout', 0)
        self.normalize_before = args.decoder_normalize_before

        # use layerNorm rather than FusedLayerNorm for exporting.
        # char_inputs can be used to determine this.
        # TODO  remove this once we update apex with the fix
        export = getattr(args, 'char_inputs', False)
        self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export)

        if no_encoder_attn:
            self.encoder_attn = None
            self.encoder_attn_layer_norm = None
        else:
            self.encoder_attn = MultiheadAttention(
                self.embed_dim,
                args.decoder_attention_heads,
                dropout=args.attention_dropout,
            )
            self.encoder_attn_layer_norm = LayerNorm(self.embed_dim,
                                                     export=export)

        self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim)
        self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)

        self.final_layer_norm = LayerNorm(self.embed_dim, export=export)
        self.need_attn = True

        self.onnx_trace = False
        self.perm_order = getattr(args, 'decoder_perm_order', 0)
        assert isinstance(self.perm_order, int) and 0 <= self.perm_order <= 5

    def set_perm_order(self, perm_order=0):
        assert isinstance(perm_order, int) and 0 <= perm_order <= 5
        self.perm_order = perm_order

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    def forward(
        self,
        x,
        encoder_out=None,
        encoder_padding_mask=None,
        incremental_state=None,
        prev_self_attn_state=None,
        prev_attn_state=None,
        self_attn_mask=None,
        self_attn_padding_mask=None,
    ):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_padding_mask (ByteTensor): binary ByteTensor of shape
                `(batch, src_len)` where padding elements are indicated by ``1``.

        Returns:
            encoded output of shape `(seq_len, batch, embed_dim)`
        """
        def func0(x):
            nonlocal incremental_state
            residual = x
            x = self.maybe_layer_norm(self.self_attn_layer_norm,
                                      x,
                                      before=True)
            if prev_self_attn_state is not None:
                if incremental_state is None:
                    incremental_state = {}
                prev_key, prev_value = prev_self_attn_state
                saved_state = {"prev_key": prev_key, "prev_value": prev_value}
                self.self_attn._set_input_buffer(incremental_state,
                                                 saved_state)
            x, attn = self.self_attn(
                query=x,
                key=x,
                value=x,
                key_padding_mask=self_attn_padding_mask,
                incremental_state=incremental_state,
                need_weights=False,
                attn_mask=self_attn_mask,
            )
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = residual + x
            x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True)
            return x

        def func1(x):
            nonlocal incremental_state
            if self.encoder_attn is not None:
                residual = x
                x = self.maybe_layer_norm(self.encoder_attn_layer_norm,
                                          x,
                                          before=True)
                if prev_attn_state is not None:
                    if incremental_state is None:
                        incremental_state = {}
                    prev_key, prev_value = prev_attn_state
                    saved_state = {
                        "prev_key": prev_key,
                        "prev_value": prev_value
                    }
                    self.encoder_attn._set_input_buffer(
                        incremental_state, saved_state)
                nonlocal attn
                x, attn = self.encoder_attn(
                    query=x,
                    key=encoder_out,
                    value=encoder_out,
                    key_padding_mask=encoder_padding_mask,
                    incremental_state=incremental_state,
                    static_kv=True,
                    need_weights=(not self.training and self.need_attn),
                )
                x = F.dropout(x, p=self.dropout, training=self.training)
                x = residual + x
                x = self.maybe_layer_norm(self.encoder_attn_layer_norm,
                                          x,
                                          after=True)
            return x

        def func2(x):
            residual = x
            x = self.maybe_layer_norm(self.final_layer_norm, x, before=True)
            x = self.activation_fn(self.fc1(x))
            x = F.dropout(x, p=self.activation_dropout, training=self.training)
            x = self.fc2(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = residual + x
            x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)
            return x

        def get_order():

            order = self.perm_order
            nonlocal x
            if order == 0:
                x = func0(x)
                x = func1(x)
                x = func2(x)
            elif order == 1:
                x = func2(x)
                x = func0(x)
                x = func1(x)
            elif order == 2:
                x = func1(x)
                x = func2(x)
                x = func0(x)
            elif order == 3:
                x = func1(x)
                x = func0(x)
                x = func2(x)
            elif order == 4:
                x = func0(x)
                x = func2(x)
                x = func1(x)
            else:
                x = func2(x)
                x = func1(x)
                x = func0(x)
            return x

        attn = None
        x = get_order()

        if self.onnx_trace and incremental_state is not None:
            saved_state = self.self_attn._get_input_buffer(incremental_state)
            self_attn_state = saved_state["prev_key"], saved_state[
                "prev_value"]
            return x, attn, self_attn_state
        return x, attn

    def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
        # Called twice per sublayer (before=True at the input, after=True at
        # the output); applies layer_norm only at the position selected by
        # self.normalize_before:
        #   post-norm (default): x = LayerNorm(x + Dropout(Sublayer(x)))
        #   pre-norm:            x = x + Dropout(Sublayer(LayerNorm(x)))
        assert before ^ after
        if after ^ self.normalize_before:
            return layer_norm(x)
        else:
            return x

    def make_generation_fast_(self, need_attn=False, **kwargs):
        self.need_attn = need_attn
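
The if/elif chain in get_order() above enumerates all six orderings of the three sublayers. For reference, the same dispatch can be written as a lookup table; the sketch below is a standalone illustration, and the names _SUBLAYER_ORDERS and apply_sublayers are illustrative stand-ins, not part of the module above.

# Lookup-table equivalent of the if/elif dispatch above: index = perm_order,
# entry = the order in which (func0, func1, func2) are applied.
_SUBLAYER_ORDERS = (
    (0, 1, 2),  # perm_order == 0: self-attn, encoder-attn, FFN
    (2, 0, 1),  # perm_order == 1
    (1, 2, 0),  # perm_order == 2
    (1, 0, 2),  # perm_order == 3
    (0, 2, 1),  # perm_order == 4
    (2, 1, 0),  # any other value
)


def apply_sublayers(x, funcs, perm_order):
    """Apply funcs = (func0, func1, func2) to x in the order picked by perm_order."""
    order = _SUBLAYER_ORDERS[perm_order if 0 <= perm_order <= 4 else 5]
    for idx in order:
        x = funcs[idx](x)
    return x

Calling apply_sublayers(x, (func0, func1, func2), self.perm_order) would reproduce the behaviour of get_order() above.
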
class TransformerDecoderLayer(nn.Module):
    """Decoder layer block.

    In the original paper each operation (multi-head attention, encoder
    attention or FFN) is postprocessed with: `dropout -> add residual ->
    layernorm`. In the tensor2tensor code they suggest that learning is more
    robust when preprocessing each layer with layernorm and postprocessing with:
    `dropout -> add residual`. We default to the approach in the paper, but the
    tensor2tensor approach can be enabled by setting
    *args.decoder_normalize_before* to ``True``.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).
    """
    def __init__(self,
                 args,
                 no_encoder_attn=False,
                 add_bias_kv=False,
                 add_zero_attn=False):
        super().__init__()
        self.embed_dim = args.decoder_embed_dim
        self.cross_self_attention = getattr(args, "cross_self_attention",
                                            False)
        self.self_attn = MultiheadAttention(
            embed_dim=self.embed_dim,
            num_heads=args.decoder_attention_heads,
            dropout=args.attention_dropout,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
            self_attention=not self.cross_self_attention,
        )
        self.dropout = args.dropout
        self.activation_fn = utils.get_activation_fn(
            activation=getattr(args, "activation_fn", "relu"))
        self.activation_dropout = getattr(args, "activation_dropout", 0)
        if self.activation_dropout == 0:
            # for backwards compatibility with models that use args.relu_dropout
            self.activation_dropout = getattr(args, "relu_dropout", 0)
        self.normalize_before = args.decoder_normalize_before

        # Use LayerNorm rather than FusedLayerNorm for exporting;
        # char_inputs can be used to determine this.
        # TODO: remove this once we update apex with the fix.
        export = getattr(args, "char_inputs", False)
        self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export)

        if no_encoder_attn:
            self.encoder_attn = None
            self.encoder_attn_layer_norm = None
        else:
            self.encoder_attn = MultiheadAttention(
                self.embed_dim,
                args.decoder_attention_heads,
                kdim=getattr(args, "encoder_embed_dim", None),
                vdim=getattr(args, "encoder_embed_dim", None),
                dropout=args.attention_dropout,
                encoder_decoder_attention=True,
            )
            self.encoder_attn_layer_norm = LayerNorm(self.embed_dim,
                                                     export=export)

        self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim)
        self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)

        self.final_layer_norm = LayerNorm(self.embed_dim, export=export)
        self.need_attn = True

        self.onnx_trace = False

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    def forward(
        self,
        x,
        encoder_out: Optional[torch.Tensor] = None,
        encoder_padding_mask: Optional[torch.Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str,
                                                   Optional[Tensor]]]] = None,
        prev_self_attn_state: Optional[List[torch.Tensor]] = None,
        prev_attn_state: Optional[List[torch.Tensor]] = None,
        self_attn_mask: Optional[torch.Tensor] = None,
        self_attn_padding_mask: Optional[torch.Tensor] = None,
        need_attn: bool = False,
        need_head_weights: bool = False,
        encoder_out2=None,
        balance_weight=None,
        encoder_padding_mask2=None,
    ):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_padding_mask (ByteTensor, optional): binary
                ByteTensor of shape `(batch, src_len)` where padding
                elements are indicated by ``1``.
            need_attn (bool, optional): return attention weights
            need_head_weights (bool, optional): return attention weights
                for each head (default: return average over heads).

        Returns:
            encoded output of shape `(seq_len, batch, embed_dim)`
        """

        attn2 = None

        if need_head_weights:
            need_attn = True

        residual = x
        if self.normalize_before:
            x = self.self_attn_layer_norm(x)
        if prev_self_attn_state is not None:
            prev_key, prev_value = prev_self_attn_state[:2]
            saved_state: Dict[str, Optional[Tensor]] = {
                "prev_key": prev_key,
                "prev_value": prev_value,
            }
            if len(prev_self_attn_state) >= 3:
                saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
            assert incremental_state is not None
            self.self_attn._set_input_buffer(incremental_state, saved_state)

        _self_attn_input_buffer = self.self_attn._get_input_buffer(
            incremental_state)

        # Cross-self-attention: until an incremental buffer has been cached,
        # let self-attention also attend over the encoder states by prepending
        # encoder_out to the keys/values and widening the masks to match.
        if self.cross_self_attention and not (
                incremental_state is not None
                and _self_attn_input_buffer is not None
                and "prev_key" in _self_attn_input_buffer):
            if self_attn_mask is not None:
                assert encoder_out is not None
                self_attn_mask = torch.cat((x.new_zeros(
                    x.size(0), encoder_out.size(0)), self_attn_mask),
                                           dim=1)
            if self_attn_padding_mask is not None:
                if encoder_padding_mask is None:
                    assert encoder_out is not None
                    encoder_padding_mask = self_attn_padding_mask.new_zeros(
                        encoder_out.size(1), encoder_out.size(0))
                self_attn_padding_mask = torch.cat(
                    (encoder_padding_mask, self_attn_padding_mask), dim=1)
            assert encoder_out is not None
            y = torch.cat((encoder_out, x), dim=0)
        else:
            y = x

        x, attn, _ = self.self_attn(
            query=x,
            key=y,
            value=y,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            need_weights=False,
            attn_mask=self_attn_mask,
        )
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        if not self.normalize_before:
            x = self.self_attn_layer_norm(x)

        if self.encoder_attn is not None:
            residual = x
            if self.normalize_before:
                x = self.encoder_attn_layer_norm(x)

            if prev_attn_state is not None:
                prev_key, prev_value = prev_attn_state[:2]
                saved_state: Dict[str, Optional[Tensor]] = {
                    "prev_key": prev_key,
                    "prev_value": prev_value,
                }
                if len(prev_attn_state) >= 3:
                    saved_state["prev_key_padding_mask"] = prev_attn_state[2]
                assert incremental_state is not None
                self.encoder_attn._set_input_buffer(incremental_state,
                                                    saved_state)

            if encoder_out2 is not None:
                # Two encoder memories: with a balance_weight the attention
                # module is given both (key/key2, value/value2) plus the weight;
                # otherwise the memories and their padding masks are
                # concatenated and attended to jointly.
                if balance_weight is not None:
                    if need_head_weights:
                        x, attn, attn2 = self.encoder_attn(
                            query=x,
                            key=encoder_out,
                            value=encoder_out,
                            key2=encoder_out2,
                            value2=encoder_out2,
                            balance_weight=balance_weight,
                            key_padding_mask=encoder_padding_mask,
                            key_padding_mask2=encoder_padding_mask2,
                            incremental_state=incremental_state,
                            static_kv=True,
                            need_weights=need_attn
                            or (not self.training and self.need_attn),
                            need_head_weights=need_head_weights,
                        )
                    else:
                        x, attn, _ = self.encoder_attn(
                            query=x,
                            key=encoder_out,
                            value=encoder_out,
                            key2=encoder_out2,
                            value2=encoder_out2,
                            balance_weight=balance_weight,
                            key_padding_mask=encoder_padding_mask,
                            key_padding_mask2=encoder_padding_mask2,
                            incremental_state=incremental_state,
                            static_kv=True,
                            need_weights=need_attn
                            or (not self.training and self.need_attn),
                            need_head_weights=need_head_weights,
                        )

                else:
                    concat_out = torch.cat([encoder_out, encoder_out2], dim=0)

                    encoder_padding = torch.cat(
                        [encoder_padding_mask, encoder_padding_mask2], dim=1)

                    x, attn, _ = self.encoder_attn(
                        query=x,
                        key=concat_out,
                        value=concat_out,
                        key_padding_mask=encoder_padding,
                        incremental_state=incremental_state,
                        static_kv=True,
                        need_weights=need_attn
                        or (not self.training and self.need_attn),
                        need_head_weights=need_head_weights,
                    )

            else:
                x, attn, _ = self.encoder_attn(
                    query=x,
                    key=encoder_out,
                    value=encoder_out,
                    key_padding_mask=encoder_padding_mask,
                    incremental_state=incremental_state,
                    static_kv=True,
                    need_weights=need_attn
                    or (not self.training and self.need_attn),
                    need_head_weights=need_head_weights,
                )

            x = F.dropout(x, p=self.dropout, training=self.training)
            x = residual + x
            if not self.normalize_before:
                x = self.encoder_attn_layer_norm(x)

        residual = x
        if self.normalize_before:
            x = self.final_layer_norm(x)
        x = self.activation_fn(self.fc1(x))
        x = F.dropout(x,
                      p=float(self.activation_dropout),
                      training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        if not self.normalize_before:
            x = self.final_layer_norm(x)

        if self.onnx_trace and incremental_state is not None:

            saved_state = self.self_attn._get_input_buffer(incremental_state)
            assert saved_state is not None
            if self_attn_padding_mask is not None:
                self_attn_state = [
                    saved_state["prev_key"],
                    saved_state["prev_value"],
                    saved_state["prev_key_padding_mask"],
                ]
            else:
                self_attn_state = [
                    saved_state["prev_key"], saved_state["prev_value"]
                ]
            return x, attn, self_attn_state

        return x, attn, attn2

    def make_generation_fast_(self, need_attn: bool = False, **kwargs):
        self.need_attn = need_attn
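
When encoder_out2 is supplied without a balance_weight, the forward above concatenates the two encoder memories along the time axis and their padding masks along the length axis, then runs a single cross-attention over the combined memory. Below is a minimal shape-level sketch of that path; it uses the stock torch.nn.MultiheadAttention as a stand-in for fairseq's MultiheadAttention (whose signature additionally includes incremental_state and static_kv), so only the tensor bookkeeping carries over.

import torch
import torch.nn as nn

# Shapes follow the convention used above: (seq_len, batch, embed_dim) for
# activations and (batch, seq_len) for padding masks (True = padding).
embed_dim, num_heads, batch = 8, 2, 3
attn = nn.MultiheadAttention(embed_dim, num_heads)  # stand-in for self.encoder_attn

x = torch.randn(5, batch, embed_dim)             # decoder states (tgt_len, B, C)
encoder_out = torch.randn(7, batch, embed_dim)   # first encoder memory
encoder_out2 = torch.randn(4, batch, embed_dim)  # second encoder memory
encoder_padding_mask = torch.zeros(batch, 7, dtype=torch.bool)
encoder_padding_mask2 = torch.zeros(batch, 4, dtype=torch.bool)

# Concatenate memories along time (dim=0) and masks along length (dim=1),
# mirroring the concat_out / encoder_padding branch of the forward above.
concat_out = torch.cat([encoder_out, encoder_out2], dim=0)  # (11, B, C)
encoder_padding = torch.cat(
    [encoder_padding_mask, encoder_padding_mask2], dim=1)   # (B, 11)

out, _ = attn(query=x, key=concat_out, value=concat_out,
              key_padding_mask=encoder_padding)
print(out.shape)  # torch.Size([5, 3, 8])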
Example No. 13
0
class MaskDecoderLayer(nn.Module):
    def __init__(self, args, no_encoder_attn=False):
        super().__init__()
        self.embed_dim = args.decoder_embed_dim
        self.self_attn = MultiheadAttention(
            self.embed_dim,
            args.decoder_attention_heads,
            dropout=args.attention_dropout,
        )
        self.dropout = args.dropout
        self.relu_dropout = args.relu_dropout
        self.normalize_before = args.decoder_normalize_before

        self.self_attn_layer_norm = LayerNorm(self.embed_dim)

        if no_encoder_attn:
            self.source_encoder_attn = None
            self.mask_encoder_attn = None
            self.encoder_attn_layer_norm = None
            self.concat_dense = None
        else:
            self.source_encoder_attn = MultiheadAttention(
                self.embed_dim,
                args.decoder_attention_heads,
                dropout=args.attention_dropout,
            )
            self.mask_encoder_attn = MultiheadAttention(
                self.embed_dim,
                args.decoder_attention_heads,
                dropout=args.attention_dropout,
            )
            self.encoder_attn_layer_norm = LayerNorm(self.embed_dim)
            self.concat_dense = Linear(2 * self.embed_dim,
                                       self.embed_dim,
                                       bias=True)

        self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim)
        self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)

        self.final_layer_norm = LayerNorm(self.embed_dim)
        self.need_attn = True

        self.onnx_trace = False

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    def forward(self,
                x,
                source_encoder_out,
                source_encoder_padding_mask,
                mask_encoder_out,
                mask_encoder_padding_mask,
                incremental_state,
                prev_self_attn_state=None,
                prev_source_attn_state=None,
                prev_mask_attn_state=None,
                self_attn_mask=None,
                self_attn_padding_mask=None):
        residual = x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True)
        if prev_self_attn_state is not None:
            if incremental_state is None:
                incremental_state = {}
            prev_key, prev_value = prev_self_attn_state
            saved_state = {"prev_key": prev_key, "prev_value": prev_value}
            self.self_attn._set_input_buffer(incremental_state, saved_state)
        x, _ = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            need_weights=False,
            attn_mask=self_attn_mask,
        )
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True)

        attn_source = None
        attn_mask = None
        if self.source_encoder_attn is not None:
            residual = x
            source_x = self.maybe_layer_norm(self.encoder_attn_layer_norm,
                                             x,
                                             before=True)
            mask_x = self.maybe_layer_norm(self.encoder_attn_layer_norm,
                                           x,
                                           before=True)

            self.set_attention_input_buffer(self.source_encoder_attn,
                                            incremental_state,
                                            prev_source_attn_state)
            self.set_attention_input_buffer(self.mask_encoder_attn,
                                            incremental_state,
                                            prev_mask_attn_state)

            source_x, attn_source = self.source_encoder_attn(
                query=source_x,
                key=source_encoder_out,
                value=source_encoder_out,
                key_padding_mask=source_encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                need_weights=(not self.training and self.need_attn),
            )

            mask_x, attn_mask = self.mask_encoder_attn(
                query=mask_x,
                key=mask_encoder_out,
                value=mask_encoder_out,
                key_padding_mask=mask_encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                need_weights=(not self.training and self.need_attn),
            )
            x = torch.cat([source_x, mask_x], dim=-1)
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = F.relu(self.concat_dense(x))
            x = residual + x
            x = self.maybe_layer_norm(self.encoder_attn_layer_norm,
                                      x,
                                      after=True)

        residual = x
        x = self.maybe_layer_norm(self.final_layer_norm, x, before=True)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, p=self.relu_dropout, training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)
        if self.onnx_trace:
            saved_state = self.self_attn._get_input_buffer(incremental_state)
            self_attn_state = saved_state["prev_key"], saved_state[
                "prev_value"]
            return x, attn_source, attn_mask, self_attn_state
        return x, attn_source, attn_mask

    def set_attention_input_buffer(self, attention_layer, incremental_state,
                                   previous_attn_state):
        if previous_attn_state is not None:
            if incremental_state is None:
                incremental_state = {}
            prev_key, prev_value = previous_attn_state
            saved_state = {"prev_key": prev_key, "prev_value": prev_value}
            attention_layer._set_input_buffer(incremental_state, saved_state)

    def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
        assert before ^ after
        if after ^ self.normalize_before:
            return layer_norm(x)
        else:
            return x

    def make_generation_fast_(self, need_attn=False, **kwargs):
        self.need_attn = need_attn
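
MaskDecoderLayer runs two parallel cross-attentions, one over the source encoder and one over the mask encoder, then fuses the two results by concatenating on the feature axis and projecting back to embed_dim through concat_dense with a ReLU before the residual connection. The sketch below isolates just that fusion step; the random tensors stand in for the two attention outputs, and the dropout probability is arbitrary.

import torch
import torch.nn as nn
import torch.nn.functional as F

embed_dim, tgt_len, batch = 8, 5, 3
concat_dense = nn.Linear(2 * embed_dim, embed_dim)  # as in __init__ above

source_x = torch.randn(tgt_len, batch, embed_dim)   # output of source_encoder_attn
mask_x = torch.randn(tgt_len, batch, embed_dim)     # output of mask_encoder_attn
residual = torch.randn(tgt_len, batch, embed_dim)   # decoder state before the block

# Concatenate the two attention streams on the channel axis, apply dropout,
# project back to embed_dim with a ReLU, then add the residual, mirroring the
# fusion in forward() above.
fused = torch.cat([source_x, mask_x], dim=-1)       # (tgt_len, B, 2 * embed_dim)
fused = F.relu(concat_dense(F.dropout(fused, p=0.1, training=False)))
out = residual + fused                              # (tgt_len, B, embed_dim)
print(out.shape)  # torch.Size([5, 3, 8])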
Example No. 14
0
class LocalTransformerDecoderLayer(nn.Module):
    """Decoder layer block.

    In the original paper each operation (multi-head attention, encoder
    attention or FFN) is postprocessed with: `dropout -> add residual ->
    layernorm`. In the tensor2tensor code they suggest that learning is more
    robust when preprocessing each layer with layernorm and postprocessing with:
    `dropout -> add residual`. We default to the approach in the paper, but the
    tensor2tensor approach can be enabled by setting
    *args.decoder_normalize_before* to ``True``.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).
    """

    def __init__(self, args, no_encoder_attn=False, num_layer=0):
        super().__init__()
        self.embed_dim = args.decoder_embed_dim
        self.self_attn = MultiheadAttention(
            self.embed_dim, args.decoder_attention_heads,
            dropout=args.attention_dropout,
        )
        self.dropout = args.dropout
        self.relu_dropout = args.relu_dropout
        self.normalize_before = args.decoder_normalize_before

        self.self_attn_layer_norm = LayerNorm(self.embed_dim)

        if no_encoder_attn:
            self.encoder_attn = None
            self.encoder_attn_layer_norm = None
        else:
            self.encoder_attn = MultiheadAttention(
                self.embed_dim, args.decoder_attention_heads,
                dropout=args.attention_dropout,
            )
            self.encoder_attn_layer_norm = LayerNorm(self.embed_dim)

        self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim)
        self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)

        self.final_layer_norm = LayerNorm(self.embed_dim)
        self.need_attn = True

        self.onnx_trace = False

        self.kernel_size = args.kernel_size
        self.padding_idx = 1

        self.use_local_decoder = args.use_local_decoder

        if isinstance(self.kernel_size, list):
            self.kernel_size = self.kernel_size[num_layer]

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    def forward(self, x, encoder_out, encoder_padding_mask, incremental_state,
                prev_self_attn_state=None, prev_attn_state=None, self_attn_mask=None,
                self_attn_padding_mask=None):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_padding_mask (ByteTensor): binary ByteTensor of shape
                `(batch, src_len)` where padding elements are indicated by ``1``.

        Returns:
            encoded output of shape `(batch, src_len, embed_dim)`
        """

        ############################# ADDED PART ####################################
        # For self-attention: pad the target length up to a multiple of
        # kernel_size and fold the time axis so that self-attention runs over
        # sequences of length kernel_size with an enlarged batch.

        if self.use_local_decoder:
            tgt_len, batch_size, embed_dim = x.size()

            size_to_add = self.kernel_size - tgt_len % self.kernel_size

            x2 = torch.zeros(tgt_len + size_to_add, batch_size, embed_dim,
                             dtype=x.dtype, device=x.device)
            x2[:tgt_len, :batch_size, :] = x
            x = x2.view(self.kernel_size, -1, embed_dim)

            if self_attn_padding_mask is None:
                self_attn_padding_mask = torch.zeros(batch_size, tgt_len,
                                                     dtype=torch.uint8,
                                                     device=x.device)
            self_attn_padding_mask2 = torch.zeros(
                batch_size, tgt_len + size_to_add,
                dtype=encoder_padding_mask.dtype,
                device=encoder_padding_mask.device)
            self_attn_padding_mask2.fill_(1)
            self_attn_padding_mask2[:, :tgt_len] = self_attn_padding_mask
            self_attn_padding_mask = self_attn_padding_mask2.view(
                -1, self.kernel_size)

        ############################# END ADDED PART ###################################

        residual = x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True)
        if prev_self_attn_state is not None:
            if incremental_state is None:
                incremental_state = {}
            prev_key, prev_value = prev_self_attn_state
            saved_state = {"prev_key": prev_key, "prev_value": prev_value}
            self.self_attn._set_input_buffer(incremental_state, saved_state)

        ############################# MODIFIED PART ####################################

        current_attn_mask = self_attn_mask

        if self.use_local_decoder:
            current_attn_mask = self.buffered_future_mask(x) if incremental_state is None else None

        x, _ = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            need_weights=False,
            attn_mask=current_attn_mask,
        )
        ############################# END MODIFIED PART ####################################


        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True)

        ############################# ADDED PART ####################################
        if self.use_local_decoder:
            x2 = x.view(-1, batch_size, self.embed_dim)
            x = x2[:tgt_len, :, :]
        ############################# END ADDED PART ####################################

        attn = None
        if self.encoder_attn is not None:
            residual = x
            x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, before=True)
            if prev_attn_state is not None:
                if incremental_state is None:
                    incremental_state = {}
                prev_key, prev_value = prev_attn_state
                saved_state = {"prev_key": prev_key, "prev_value": prev_value}
                self.encoder_attn._set_input_buffer(incremental_state, saved_state)

            x, attn = self.encoder_attn(
                query=x,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                need_weights=(not self.training and self.need_attn),
            )
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = residual + x
            x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, after=True)
        
        residual = x
        x = self.maybe_layer_norm(self.final_layer_norm, x, before=True)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, p=self.relu_dropout, training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)
        if self.onnx_trace:
            saved_state = self.self_attn._get_input_buffer(incremental_state)
            self_attn_state = saved_state["prev_key"], saved_state["prev_value"]
            return x, attn, self_attn_state
        return x, attn

    def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
        assert before ^ after
        if after ^ self.normalize_before:
            return layer_norm(x)
        else:
            return x

    def make_generation_fast_(self, need_attn=False, **kwargs):
        self.need_attn = need_attn

    # Normally defined only on LocalTransformerDecoder; duplicated here so the
    # layer can rebuild the future mask for its folded self-attention input.
    def buffered_future_mask(self, tensor):
        dim = tensor.size(0)
        if (not hasattr(self, '_future_mask') or self._future_mask is None
                or self._future_mask.device != tensor.device):
            self._future_mask = torch.triu(
                utils.fill_with_neg_inf(tensor.new(dim, dim)), 1)
        if self._future_mask.size(0) < dim:
            self._future_mask = torch.triu(
                utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1)
        return self._future_mask[:dim, :dim]
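
The use_local_decoder branch above pads the target length up to a multiple of kernel_size and folds the time axis into the batch axis, so self-attention sees sequences of length kernel_size with a correspondingly larger batch; after attention the fold is undone and the padding cropped away. The following standalone sketch checks only that shape bookkeeping (the attention call itself is elided).

import torch

tgt_len, batch_size, embed_dim, kernel_size = 10, 3, 8, 4
x = torch.randn(tgt_len, batch_size, embed_dim)

# Pad the time axis to a multiple of kernel_size, as in the "ADDED PART" above.
size_to_add = kernel_size - tgt_len % kernel_size            # 2 in this example
x2 = x.new_zeros(tgt_len + size_to_add, batch_size, embed_dim)
x2[:tgt_len] = x

# Fold: attention now runs over sequences of length kernel_size with batch 9.
folded = x2.view(kernel_size, -1, embed_dim)                 # (4, 9, 8)

# ... self-attention over `folded` would happen here ...

# Unfold and crop, as in the second "ADDED PART" after self-attention above.
restored = folded.view(-1, batch_size, embed_dim)[:tgt_len]  # (10, 3, 8)
assert torch.equal(restored, x)                              # round trip recovers x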