Example #1
class TransformerDecoderLayer(nn.Module):
    """Decoder layer block.

    In the original paper each operation (multi-head attention, encoder
    attention or FFN) is postprocessed with: `dropout -> add residual ->
    layernorm`. In the tensor2tensor code they suggest that learning is more
    robust when preprocessing each layer with layernorm and postprocessing with:
    `dropout -> add residual`. We default to the approach in the paper, but the
    tensor2tensor approach can be enabled by setting
    *args.decoder_normalize_before* to ``True``.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        no_encoder_attn (bool, optional): whether to skip attention over
            encoder outputs (default: False).
    """

    def __init__(
        self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False
    ):
        super().__init__()
        self.embed_dim = args.decoder_embed_dim
        self.self_attn = MultiheadAttention(
            embed_dim=self.embed_dim,
            num_heads=args.decoder_attention_heads,
            dropout=args.attention_dropout,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
            self_attention=True,
        )
        self.dropout = args.dropout
        self.activation_fn = utils.get_activation_fn(
            activation=getattr(args, "activation_fn", "relu")
        )
        self.activation_dropout = getattr(args, "activation_dropout", 0)
        if self.activation_dropout == 0:
            # for backwards compatibility with models that use args.relu_dropout
            self.activation_dropout = getattr(args, "relu_dropout", 0)
        self.normalize_before = args.decoder_normalize_before

        # use LayerNorm rather than FusedLayerNorm for exporting.
        # char_inputs can be used to determine this.
        # TODO  remove this once we update apex with the fix
        export = getattr(args, "char_inputs", False)
        self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export)

        if no_encoder_attn:
            self.encoder_attn = None
            self.encoder_attn_layer_norm = None
        else:
            self.encoder_attn = MultiheadAttention(
                self.embed_dim,
                args.decoder_attention_heads,
                kdim=getattr(args, "encoder_embed_dim", None),
                vdim=getattr(args, "encoder_embed_dim", None),
                dropout=args.attention_dropout,
                encoder_decoder_attention=True,
            )
            self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=export)

        self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim)
        self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)

        self.final_layer_norm = LayerNorm(self.embed_dim, export=export)
        self.need_attn = True

        self.onnx_trace = False

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    def forward(self, input):
        """
        Args:
            input (Tuple):
                input[0] (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
                input[1] (Tensor): encoder output of shape `(src_len, batch, embed_dim)`
                input[2] (ByteTensor/FloatTensor): encoder padding mask -
                    binary ByteTensor of shape `(batch, src_len)` where padding elements
                    are indicated by ``1``.
        Returns:
            output (Tuple):
                output[0] (Tensor): decoded output of shape `(seq_len, batch, embed_dim)`
                output[1] (Tensor): encoder output, passed through unchanged
                output[2] (ByteTensor/FloatTensor): encoder padding mask, passed through
            If *input* is a plain Tensor rather than a tuple, only the decoded
            output is returned.
        """
        # Note: incremental state is not yet supported
        mt_task = False
        if isinstance(input, tuple):
            x = input[0]
            encoder_out = input[1]
            encoder_padding_mask = input[2]
            incremental_state = None
            mt_task = True
        else:
            x = input
            encoder_out = None
            encoder_padding_mask = None
            incremental_state = None

        if incremental_state is None:
            self_attn_mask = self.buffered_future_mask(x)
        else:
            self_attn_mask = None

        # TODO: add back prev_self_attn_state, prev_attn_state,
        # self_attn_padding_mask
        prev_self_attn_state = None
        prev_attn_state = None
        self_attn_padding_mask = None

        residual = x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True)
        if prev_self_attn_state is not None:
            if incremental_state is None:
                incremental_state = {}
            prev_key, prev_value = prev_self_attn_state
            saved_state = {"prev_key": prev_key, "prev_value": prev_value}
            self.self_attn._set_input_buffer(incremental_state, saved_state)
        x, attn = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            need_weights=False,
            attn_mask=self_attn_mask,
        )
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True)

        if self.encoder_attn is not None:
            residual = x
            x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, before=True)
            if prev_attn_state is not None:
                if incremental_state is None:
                    incremental_state = {}
                prev_key, prev_value = prev_attn_state
                saved_state = {"prev_key": prev_key, "prev_value": prev_value}
                self.encoder_attn._set_input_buffer(incremental_state, saved_state)
            x, attn = self.encoder_attn(
                query=x,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                need_weights=(not self.training and self.need_attn),
            )
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = residual + x
            x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, after=True)

        residual = x
        x = self.maybe_layer_norm(self.final_layer_norm, x, before=True)
        x = self.activation_fn(self.fc1(x))
        x = F.dropout(x, p=self.activation_dropout, training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)

        if mt_task:
            return (x, encoder_out, encoder_padding_mask)
        return x

    def buffered_future_mask(self, tensor):
        dim = tensor.size(0)
        if (
            not hasattr(self, "_future_mask")
            or self._future_mask is None
            or self._future_mask.device != tensor.device
        ):
            self._future_mask = torch.triu(
                utils.fill_with_neg_inf(tensor.new(dim, dim)), 1
            )
        if self._future_mask.size(0) < dim:
            self._future_mask = torch.triu(
                utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1
            )
        return self._future_mask[:dim, :dim]

    def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
        assert before ^ after
        if after ^ self.normalize_before:
            return layer_norm(x)
        else:
            return x

    def make_generation_fast_(self, need_attn=False, **kwargs):
        self.need_attn = need_attn
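
A minimal, self-contained sketch of the pre-/post-layernorm toggle that `normalize_before` and `maybe_layer_norm` implement in the example above. `ToyResidualBlock` is a hypothetical module used only for illustration, not part of fairseq.

# Hedged sketch: one residual sub-block that switches between the post-norm
# ordering from the paper (dropout -> add residual -> layernorm) and the
# tensor2tensor pre-norm ordering (layernorm -> sub-layer -> dropout -> add).
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyResidualBlock(nn.Module):
    def __init__(self, dim, normalize_before=False, dropout=0.1):
        super().__init__()
        self.fc = nn.Linear(dim, dim)
        self.layer_norm = nn.LayerNorm(dim)
        self.normalize_before = normalize_before
        self.dropout = dropout

    def maybe_layer_norm(self, x, before=False, after=False):
        assert before ^ after
        # Apply the norm either before the sub-layer (pre-norm) or after the
        # residual addition (post-norm), mirroring the logic in the layer above.
        if after ^ self.normalize_before:
            return self.layer_norm(x)
        return x

    def forward(self, x):
        residual = x
        x = self.maybe_layer_norm(x, before=True)
        x = F.relu(self.fc(x))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        return self.maybe_layer_norm(x, after=True)


if __name__ == "__main__":
    x = torch.randn(5, 2, 8)  # (seq_len, batch, dim)
    post_norm = ToyResidualBlock(8, normalize_before=False)
    pre_norm = ToyResidualBlock(8, normalize_before=True)
    print(post_norm(x).shape, pre_norm(x).shape)  # both torch.Size([5, 2, 8])
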
Example #2
class TransformerDecoderLayer(nn.Module):
    """Decoder layer block.

    In the original paper each operation (multi-head attention, encoder
    attention or FFN) is postprocessed with: `dropout -> add residual ->
    layernorm`. In the tensor2tensor code they suggest that learning is more
    robust when preprocessing each layer with layernorm and postprocessing with:
    `dropout -> add residual`. We default to the approach in the paper, but the
    tensor2tensor approach can be enabled by setting
    *args.decoder_normalize_before* to ``True``.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        no_encoder_attn (bool, optional): whether to skip attention over
            encoder outputs (default: False).
    """
    def __init__(self, args, no_encoder_attn=False):
        super().__init__()
        self.embed_dim = args.decoder_embed_dim
        if args.max_relative_length == -1:
            self.self_attn = MultiheadAttention(
                self.embed_dim,
                args.decoder_attention_heads,
                dropout=args.attention_dropout,
            )
        else:
            self.self_attn = RelativeMultiheadAttention(
                self.embed_dim,
                args.decoder_attention_heads,
                args.max_relative_length,
                dropout=args.attention_dropout,
                k_only=args.k_only,
            )

        self.dropout = args.dropout
        self.relu_dropout = args.relu_dropout
        self.normalize_before = args.decoder_normalize_before

        self.self_attn_layer_norm = LayerNorm(self.embed_dim)

        if no_encoder_attn:
            self.encoder_attn = None
            self.encoder_attn_layer_norm = None
        else:
            self.encoder_attn = MultiheadAttention(
                self.embed_dim,
                args.decoder_attention_heads,
                dropout=args.attention_dropout,
            )
            self.encoder_attn_layer_norm = LayerNorm(self.embed_dim)

        self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim)
        self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)

        self.final_layer_norm = LayerNorm(self.embed_dim)
        self.need_attn = True
        self.decoder_position_dropout = args.decoder_position_dropout
        self.onnx_trace = False

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    def forward(self,
                x,
                encoder_out,
                encoder_padding_mask,
                incremental_state,
                prev_self_attn_state=None,
                prev_attn_state=None,
                self_attn_mask=None,
                self_attn_padding_mask=None):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_padding_mask (ByteTensor): binary ByteTensor of shape
                `(batch, src_len)` where padding elements are indicated by ``1``.

        Returns:
            encoded output of shape `(seq_len, batch, embed_dim)`
        """
        residual = x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True)
        if prev_self_attn_state is not None:
            if incremental_state is None:
                incremental_state = {}
            prev_key, prev_value = prev_self_attn_state
            saved_state = {"prev_key": prev_key, "prev_value": prev_value}
            self.self_attn._set_input_buffer(incremental_state, saved_state)
        x, _ = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            need_weights=False,
            attn_mask=self_attn_mask,
        )
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True)

        attn = None
        if self.encoder_attn is not None:
            residual = x
            x = self.maybe_layer_norm(self.encoder_attn_layer_norm,
                                      x,
                                      before=True)
            if prev_attn_state is not None:
                if incremental_state is None:
                    incremental_state = {}
                prev_key, prev_value = prev_attn_state
                saved_state = {"prev_key": prev_key, "prev_value": prev_value}
                self.encoder_attn._set_input_buffer(incremental_state,
                                                    saved_state)
            x, attn = self.encoder_attn(
                query=x,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                need_weights=(not self.training and self.need_attn),
            )
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = residual + x
            x = self.maybe_layer_norm(self.encoder_attn_layer_norm,
                                      x,
                                      after=True)

        residual = x
        x = self.maybe_layer_norm(self.final_layer_norm, x, before=True)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, p=self.relu_dropout, training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)
        if self.onnx_trace:
            saved_state = self.self_attn._get_input_buffer(incremental_state)
            self_attn_state = saved_state["prev_key"], saved_state[
                "prev_value"]
            return x, attn, self_attn_state
        return x, attn

    def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
        assert before ^ after
        if after ^ self.normalize_before:
            return layer_norm(x)
        else:
            return x

    def position_dropout(self, x):
        if self.training and self.decoder_position_dropout != 0:
            position_mask = (torch.rand(x.size(0)) >
                             self.decoder_position_dropout).view(
                                 -1, 1, 1).cuda().half()
            x = x * position_mask
        return x

    def make_generation_fast_(self, need_attn=False, **kwargs):
        self.need_attn = need_attn
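
The `position_dropout` helper above zeroes whole time steps but hard-codes `.cuda().half()`. Below is a hedged, device- and dtype-agnostic sketch of the same idea; `toy_position_dropout` is a hypothetical name and the function is kept outside the class so it stays runnable on CPU.

# Hedged sketch (illustrative only): drop entire target positions with
# probability p, building the mask on the input's own device and dtype.
import torch


def toy_position_dropout(x, p, training=True):
    # x: (seq_len, batch, embed_dim); each time step is kept with prob (1 - p).
    if not training or p == 0:
        return x
    keep = (torch.rand(x.size(0), device=x.device) > p).to(x.dtype)
    return x * keep.view(-1, 1, 1)


if __name__ == "__main__":
    x = torch.randn(6, 3, 4)
    print(toy_position_dropout(x, p=0.5).shape)  # torch.Size([6, 3, 4])
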
Example #3
class LightConvDecoderLayer(nn.Module):
    """Decoder layer block.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        no_encoder_attn (bool, optional): whether to skip attention over
            encoder outputs. Default: ``False``
        kernel_size: kernel size of the convolution
    """
    def __init__(self, args, no_encoder_attn=False, kernel_size=0):
        super().__init__()
        self.embed_dim = args.decoder_embed_dim
        self.conv_dim = args.decoder_conv_dim
        if args.decoder_glu:
            self.linear1 = Linear(self.embed_dim, 2 * self.conv_dim)
            self.act = nn.GLU()
        else:
            self.linear1 = Linear(self.embed_dim, self.conv_dim)
            self.act = None
        if args.decoder_conv_type == "lightweight":
            self.conv = LightweightConv(
                self.conv_dim,
                kernel_size,
                padding_l=kernel_size - 1,
                weight_softmax=args.weight_softmax,
                num_heads=args.decoder_attention_heads,
                weight_dropout=args.weight_dropout,
            )
        elif args.decoder_conv_type == "dynamic":
            self.conv = DynamicConv(
                self.conv_dim,
                kernel_size,
                padding_l=kernel_size - 1,
                weight_softmax=args.weight_softmax,
                num_heads=args.decoder_attention_heads,
                weight_dropout=args.weight_dropout,
            )
        else:
            raise NotImplementedError
        self.linear2 = Linear(self.conv_dim, self.embed_dim)

        self.dropout_module = FairseqDropout(
            args.dropout, module_name=self.__class__.__name__)
        self.relu_dropout_module = FairseqDropout(
            args.relu_dropout, module_name=self.__class__.__name__)
        self.input_dropout_module = FairseqDropout(
            args.input_dropout, module_name=self.__class__.__name__)
        self.normalize_before = args.decoder_normalize_before

        self.conv_layer_norm = LayerNorm(self.embed_dim)

        if no_encoder_attn:
            self.encoder_attn = None
            self.encoder_attn_layer_norm = None
        else:
            self.encoder_attn = MultiheadAttention(
                self.embed_dim,
                args.decoder_attention_heads,
                dropout=args.attention_dropout,
                encoder_decoder_attention=True,
            )
            self.encoder_attn_layer_norm = LayerNorm(self.embed_dim)

        self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim)
        self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)

        self.final_layer_norm = LayerNorm(self.embed_dim)
        self.need_attn = True

    def forward(
        self,
        x,
        encoder_out,
        encoder_padding_mask,
        incremental_state,
        prev_conv_state=None,
        prev_attn_state=None,
        conv_mask=None,
        conv_padding_mask=None,
    ):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_padding_mask (ByteTensor): binary ByteTensor of shape
                `(batch, src_len)` where padding elements are indicated by ``1``.

        Returns:
            encoded output of shape `(seq_len, batch, embed_dim)`
        """
        residual = x
        x = self.maybe_layer_norm(self.conv_layer_norm, x, before=True)
        if prev_conv_state is not None:
            if incremental_state is None:
                incremental_state = {}
            self.conv._set_input_buffer(incremental_state, prev_conv_state)
        x = self.input_dropout_module(x)
        x = self.linear1(x)
        if self.act is not None:
            x = self.act(x)
        x = self.conv(x, incremental_state=incremental_state)
        x = self.linear2(x)
        x = self.dropout_module(x)
        x = residual + x
        x = self.maybe_layer_norm(self.conv_layer_norm, x, after=True)

        attn = None
        if self.encoder_attn is not None:
            residual = x
            x = self.maybe_layer_norm(self.encoder_attn_layer_norm,
                                      x,
                                      before=True)
            if prev_attn_state is not None:
                if incremental_state is None:
                    incremental_state = {}
                prev_key, prev_value = prev_attn_state
                saved_state = {"prev_key": prev_key, "prev_value": prev_value}
                self.encoder_attn._set_input_buffer(incremental_state,
                                                    saved_state)
            x, attn = self.encoder_attn(
                query=x,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                need_weights=(not self.training and self.need_attn),
            )
            x = self.dropout_module(x)
            x = residual + x
            x = self.maybe_layer_norm(self.encoder_attn_layer_norm,
                                      x,
                                      after=True)

        residual = x
        x = self.maybe_layer_norm(self.final_layer_norm, x, before=True)
        x = F.relu(self.fc1(x))
        x = self.relu_dropout_module(x)
        x = self.fc2(x)
        x = self.dropout_module(x)
        x = residual + x
        x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)
        return x, attn

    def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
        assert before ^ after
        if after ^ self.normalize_before:
            return layer_norm(x)
        else:
            return x

    def make_generation_fast_(self, need_attn=False, **kwargs):
        self.need_attn = need_attn

    def extra_repr(self):
        return (
            "dropout={}, relu_dropout={}, input_dropout={}, normalize_before={}"
            .format(
                self.dropout_module.p,
                self.relu_dropout_module.p,
                self.input_dropout_module.p,
                self.normalize_before,
            ))
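
A quick shape check (illustrative only) for the `decoder_glu` branch above: `linear1` expands `embed_dim` to `2 * conv_dim`, and `nn.GLU`, which gates over the last dimension by default, halves it back to `conv_dim` before the lightweight/dynamic convolution. The concrete sizes below are assumptions chosen for the demo.

# Hedged sketch: Linear expansion followed by GLU gating, as in the GLU branch.
import torch
import torch.nn as nn

embed_dim, conv_dim = 8, 8
linear1 = nn.Linear(embed_dim, 2 * conv_dim)
act = nn.GLU()  # splits the last dimension in half: values * sigmoid(gates)

x = torch.randn(5, 2, embed_dim)   # (seq_len, batch, embed_dim)
h = act(linear1(x))
print(h.shape)                     # torch.Size([5, 2, 8]) == conv_dim
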
Example #4
class LocalTransformerDecoderLayer(nn.Module):
    """Decoder layer block.

    In the original paper each operation (multi-head attention, encoder
    attention or FFN) is postprocessed with: `dropout -> add residual ->
    layernorm`. In the tensor2tensor code they suggest that learning is more
    robust when preprocessing each layer with layernorm and postprocessing with:
    `dropout -> add residual`. We default to the approach in the paper, but the
    tensor2tensor approach can be enabled by setting
    *args.decoder_normalize_before* to ``True``.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        no_encoder_attn (bool, optional): whether to skip attention over
            encoder outputs (default: False).
    """

    def __init__(self, args, no_encoder_attn=False, num_layer=0):
        super().__init__()
        self.embed_dim = args.decoder_embed_dim
        self.self_attn = MultiheadAttention(
            self.embed_dim, args.decoder_attention_heads,
            dropout=args.attention_dropout,
        )
        self.dropout = args.dropout
        self.relu_dropout = args.relu_dropout
        self.normalize_before = args.decoder_normalize_before

        self.self_attn_layer_norm = LayerNorm(self.embed_dim)

        if no_encoder_attn:
            self.encoder_attn = None
            self.encoder_attn_layer_norm = None
        else:
            self.encoder_attn = MultiheadAttention(
                self.embed_dim, args.decoder_attention_heads,
                dropout=args.attention_dropout,
            )
            self.encoder_attn_layer_norm = LayerNorm(self.embed_dim)

        self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim)
        self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)

        self.final_layer_norm = LayerNorm(self.embed_dim)
        self.need_attn = True

        self.onnx_trace = False

        self.kernel_size = args.kernel_size
        self.padding_idx = 1

        self.use_local_decoder = args.use_local_decoder

        if type(self.kernel_size) == list:
            self.kernel_size = self.kernel_size[num_layer]
        

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    def forward(self, x, encoder_out, encoder_padding_mask, incremental_state,
                prev_self_attn_state=None, prev_attn_state=None, self_attn_mask=None,
                self_attn_padding_mask=None):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_padding_mask (ByteTensor): binary ByteTensor of shape
                `(batch, src_len)` where padding elements are indicated by ``1``.

        Returns:
            encoded output of shape `(seq_len, batch, embed_dim)`
        """

        ############################# ADDED PART ####################################
        #For self attention

        if self.use_local_decoder:
            tgt_len, batch_size, embed_dim = x.size()

            size_to_add = self.kernel_size - tgt_len % self.kernel_size

            x2 = torch.zeros(tgt_len + size_to_add, batch_size, embed_dim,
                             dtype=x.dtype, device=x.device)
            x2[:tgt_len, :batch_size, :] = x
            x = x2.view(self.kernel_size, -1, embed_dim)

            if self_attn_padding_mask is None:
                self_attn_padding_mask = torch.zeros(batch_size, tgt_len, dtype=torch.uint8, device=x.device)
            self_attn_padding_mask2 = torch.zeros(batch_size, tgt_len+size_to_add, dtype=encoder_padding_mask.dtype, device=encoder_padding_mask.device)
            self_attn_padding_mask2.fill_(1)
            self_attn_padding_mask2[:, :tgt_len] = self_attn_padding_mask
            self_attn_padding_mask = self_attn_padding_mask2.view(-1, self.kernel_size)

        ############################# END ADDED PART ###################################

        residual = x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True)
        if prev_self_attn_state is not None:
            if incremental_state is None:
                incremental_state = {}
            prev_key, prev_value = prev_self_attn_state
            saved_state = {"prev_key": prev_key, "prev_value": prev_value}
            self.self_attn._set_input_buffer(incremental_state, saved_state)

        ############################# MODIFIED PART ####################################

        current_attn_mask = self_attn_mask

        if self.use_local_decoder:
            current_attn_mask = self.buffered_future_mask(x) if incremental_state is None else None

        x, _ = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            need_weights=False,
            attn_mask=current_attn_mask,
            )
        
        ############################# END MODIFIED PART ####################################


        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True)

        ############################# ADDED PART ####################################
        if self.use_local_decoder:
            x2 = x.view(-1, batch_size, self.embed_dim)
            x = x2[:tgt_len, :, :]
        ############################# END ADDED PART ####################################

        attn = None
        if self.encoder_attn is not None:
            residual = x
            x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, before=True)
            if prev_attn_state is not None:
                if incremental_state is None:
                    incremental_state = {}
                prev_key, prev_value = prev_attn_state
                saved_state = {"prev_key": prev_key, "prev_value": prev_value}
                self.encoder_attn._set_input_buffer(incremental_state, saved_state)

            x, attn = self.encoder_attn(
                query=x,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                need_weights=(not self.training and self.need_attn),
            )
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = residual + x
            x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, after=True)
        
        residual = x
        x = self.maybe_layer_norm(self.final_layer_norm, x, before=True)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, p=self.relu_dropout, training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)
        if self.onnx_trace:
            saved_state = self.self_attn._get_input_buffer(incremental_state)
            self_attn_state = saved_state["prev_key"], saved_state["prev_value"]
            return x, attn, self_attn_state
        return x, attn

    def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
        assert before ^ after
        if after ^ self.normalize_before:
            return layer_norm(x)
        else:
            return x

    def make_generation_fast_(self, need_attn=False, **kwargs):
        self.need_attn = need_attn

    # Normally lives in LocalTransformerDecoder; duplicated here to build the
    # local-window self-attention mask.
    def buffered_future_mask(self, tensor):
        dim = tensor.size(0)
        if not hasattr(self, '_future_mask') or self._future_mask is None or self._future_mask.device != tensor.device:
            self._future_mask = torch.triu(utils.fill_with_neg_inf(tensor.new(dim, dim)), 1)
        if self._future_mask.size(0) < dim:
            self._future_mask = torch.triu(utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1)
        return self._future_mask[:dim, :dim]
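
The "ADDED PART" above pads the target length up to a multiple of `kernel_size`, reshapes the sequence so the first dimension has size `kernel_size`, and after self-attention restores the original shape and drops the padding. A standalone sketch of that round trip, with illustrative sizes:

# Hedged sketch: pad, fold into a kernel_size-first view, unfold, truncate.
import torch

kernel_size = 4
tgt_len, batch_size, embed_dim = 10, 2, 8
x = torch.randn(tgt_len, batch_size, embed_dim)

# Pad up to a multiple of kernel_size, exactly as in forward() above.
size_to_add = kernel_size - tgt_len % kernel_size             # 2 here
x2 = torch.zeros(tgt_len + size_to_add, batch_size, embed_dim, dtype=x.dtype)
x2[:tgt_len] = x
folded = x2.view(kernel_size, -1, embed_dim)                  # (4, 6, 8)

# ... local self-attention would run on `folded` ...

# Undo the reshape and drop the padded positions, as after the attention call.
restored = folded.view(-1, batch_size, embed_dim)[:tgt_len]   # (10, 2, 8)
print(folded.shape, restored.shape, torch.equal(restored, x))  # ..., True
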
Example #5
class TransformerDecoderLayer(nn.Module):
    def __init__(
        self,
        embedding_dim: int = 768,
        ffn_embedding_dim: int = 3072,
        num_attention_heads: int = 8,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        activation_fn: str = 'relu',
        add_bias_kv: bool = False,
        add_zero_attn: bool = False,
        export: bool = False,
    ):
        super().__init__()

        self.embedding_dim = embedding_dim
        self.dropout = dropout
        self.activation_dropout = activation_dropout

        # Initialize blocks
        self.activation_fn = utils.get_activation_fn(activation_fn)
        self.self_attn = MultiheadAttention(self.embedding_dim,
                                            num_attention_heads,
                                            dropout=attention_dropout,
                                            add_bias_kv=add_bias_kv,
                                            add_zero_attn=add_zero_attn,
                                            self_attention=True)

        # layer norm associated with the self attention layer
        self.self_attn_layer_norm = LayerNorm(self.embedding_dim,
                                              export=export)

        self.encoder_attn = MultiheadAttention(
            self.embedding_dim,
            num_attention_heads,
            kdim=embedding_dim,
            vdim=embedding_dim,
            dropout=attention_dropout,
            encoder_decoder_attention=True,
        )
        self.encoder_attn_layer_norm = LayerNorm(self.embedding_dim,
                                                 export=export)

        self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
        self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)

        # layer norm associated with the position wise feed-forward NN
        self.final_layer_norm = LayerNorm(self.embedding_dim, export=export)
        self.need_attn = False

    def forward(
        self,
        x,
        encoder_out=None,
        encoder_mask=None,
        incremental_state=None,
        prev_self_attn_state=None,
        prev_attn_state=None,
        self_attn_mask=None,
        self_attn_padding_mask=None,
    ):
        residual = x
        if prev_self_attn_state is not None:
            if incremental_state is None:
                incremental_state = {}
            prev_key, prev_value = prev_self_attn_state
            saved_state = {"prev_key": prev_key, "prev_value": prev_value}
            self.self_attn._set_input_buffer(incremental_state, saved_state)

        x, attn = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            need_weights=False,
            attn_mask=self_attn_mask,
        )

        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.self_attn_layer_norm(x)

        residual = x
        if prev_attn_state is not None:
            if incremental_state is None:
                incremental_state = {}
            prev_key, prev_value = prev_attn_state
            saved_state = {"prev_key": prev_key, "prev_value": prev_value}
            self.encoder_attn._set_input_buffer(incremental_state, saved_state)
        x, attn = self.encoder_attn(
            query=x,
            key=encoder_out,
            value=encoder_out,
            key_padding_mask=encoder_mask,
            incremental_state=incremental_state,
            static_kv=True,
            need_weights=(not self.training and self.need_attn),
        )
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.encoder_attn_layer_norm(x)

        residual = x
        x = self.activation_fn(self.fc1(x))
        x = F.dropout(x, p=self.activation_dropout, training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.final_layer_norm(x)
        return x, attn

    def make_generation_fast_(self, need_attn=False, **kwargs):
        self.need_attn = need_attn
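
A conceptual sketch (not fairseq's actual `MultiheadAttention` internals) of the `{"prev_key", "prev_value"}` buffer that `_set_input_buffer` seeds above: during incremental decoding, each step's projected key/value are appended to the cached ones so earlier target positions are never re-projected. The buffer shape used here is an assumption for illustration only.

# Hedged sketch: growing key/value cache across incremental decoding steps.
import torch

saved_state = {}          # stands in for the per-layer incremental_state buffer
num_heads, head_dim, bsz = 2, 4, 1

for step in range(3):
    # Projected key/value for the single newest target position; the
    # (bsz * num_heads, 1, head_dim) layout is an illustrative assumption.
    k_new = torch.randn(bsz * num_heads, 1, head_dim)
    v_new = torch.randn(bsz * num_heads, 1, head_dim)
    if "prev_key" in saved_state:
        k = torch.cat([saved_state["prev_key"], k_new], dim=1)
        v = torch.cat([saved_state["prev_value"], v_new], dim=1)
    else:
        k, v = k_new, v_new
    saved_state["prev_key"], saved_state["prev_value"] = k, v
    print(step, k.shape)  # the cached length grows by one each step
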
Example #6
class transformer_with_copyDecoderLayer(nn.Module):
    """Decoder layer block.

    In the original paper each operation (multi-head attention, encoder
    attention or FFN) is postprocessed with: `dropout -> add residual ->
    layernorm`. In the tensor2tensor code they suggest that learning is more
    robust when preprocessing each layer with layernorm and postprocessing with:
    `dropout -> add residual`. We default to the approach in the paper, but the
    tensor2tensor approach can be enabled by setting
    *args.decoder_normalize_before* to ``True``.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        no_encoder_attn (bool, optional): whether to skip attention over
            encoder outputs (default: False).
    """
    def __init__(self,
                 args,
                 no_encoder_attn=False,
                 add_bias_kv=False,
                 add_zero_attn=False):
        super().__init__()
        self.embed_dim = args.decoder_embed_dim
        self.self_attn = MultiheadAttention(
            embed_dim=self.embed_dim,
            num_heads=args.decoder_attention_heads,
            dropout=args.attention_dropout,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
        )
        self.dropout = args.dropout
        self.activation_fn = utils.get_activation_fn(
            activation=getattr(args, 'activation_fn', 'relu'))
        self.activation_dropout = getattr(args, 'activation_dropout', 0)
        if self.activation_dropout == 0:
            # for backwards compatibility with models that use args.relu_dropout
            self.activation_dropout = getattr(args, 'relu_dropout', 0)
        self.normalize_before = args.decoder_normalize_before

        self.self_attn_layer_norm = LayerNorm(self.embed_dim)

        if no_encoder_attn:
            self.encoder_attn = None
            self.encoder_attn_layer_norm = None
        else:
            self.encoder_attn = MultiheadAttention(
                self.embed_dim,
                args.decoder_attention_heads,
                dropout=args.attention_dropout,
            )
            self.encoder_attn_layer_norm = LayerNorm(self.embed_dim)

        self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim)
        self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)

        self.final_layer_norm = LayerNorm(self.embed_dim)
        self.need_attn = True

        self.onnx_trace = False

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    def forward(
        self,
        x,
        encoder_out=None,
        encoder_padding_mask=None,
        incremental_state=None,
        prev_self_attn_state=None,
        prev_attn_state=None,
        self_attn_mask=None,
        self_attn_padding_mask=None,
    ):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_padding_mask (ByteTensor): binary ByteTensor of shape
                `(batch, src_len)` where padding elements are indicated by ``1``.

        Returns:
            encoded output of shape `(seq_len, batch, embed_dim)`
        """
        residual = x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True)
        if prev_self_attn_state is not None:
            if incremental_state is None:
                incremental_state = {}
            prev_key, prev_value = prev_self_attn_state
            saved_state = {"prev_key": prev_key, "prev_value": prev_value}
            self.self_attn._set_input_buffer(incremental_state, saved_state)

        x, attn = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            need_weights=False,
            attn_mask=self_attn_mask,
        )
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True)

        if self.encoder_attn is not None:
            residual = x
            x = self.maybe_layer_norm(self.encoder_attn_layer_norm,
                                      x,
                                      before=True)
            if prev_attn_state is not None:
                if incremental_state is None:
                    incremental_state = {}
                prev_key, prev_value = prev_attn_state
                saved_state = {"prev_key": prev_key, "prev_value": prev_value}
                self.encoder_attn._set_input_buffer(incremental_state,
                                                    saved_state)
            x, attn = self.encoder_attn(
                query=x,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                need_weights=(not self.training and self.need_attn),
            )
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = residual + x
            x = self.maybe_layer_norm(self.encoder_attn_layer_norm,
                                      x,
                                      after=True)

        residual = x
        x = self.maybe_layer_norm(self.final_layer_norm, x, before=True)
        x = self.activation_fn(self.fc1(x))
        x = F.dropout(x, p=self.activation_dropout, training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)
        if self.onnx_trace and incremental_state is not None:
            saved_state = self.self_attn._get_input_buffer(incremental_state)
            self_attn_state = saved_state["prev_key"], saved_state[
                "prev_value"]
            return x, attn, self_attn_state
        return x, attn

    def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
        assert before ^ after
        if after ^ self.normalize_before:
            return layer_norm(x)
        else:
            return x

    def make_generation_fast_(self, need_attn=False, **kwargs):
        self.need_attn = need_attn
Example #7
class TransformerDecoderLayer(nn.Module):
    """Decoder layer block.

    In the original paper each operation (multi-head attention, encoder
    attention or FFN) is postprocessed with: `dropout -> add residual ->
    layernorm`. In the tensor2tensor code they suggest that learning is more
    robust when preprocessing each layer with layernorm and postprocessing with:
    `dropout -> add residual`. We default to the approach in the paper, but the
    tensor2tensor approach can be enabled by setting
    *args.decoder_normalize_before* to ``True``.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        no_encoder_attn (bool, optional): whether to skip attention over
            encoder outputs (default: False).
    """
    def __init__(self,
                 args,
                 self_attn_pattern=None,
                 encoder_attn_pattern=None,
                 no_encoder_attn=False,
                 add_bias_kv=False,
                 add_zero_attn=False):
        super().__init__()
        if self_attn_pattern is not None and args.PRUNE_DEC_SELF_ATTN:
            cpu_np_pattern = self_attn_pattern.cpu().numpy()
            d1_bounds, d2_bounds = find_bounds(cpu_np_pattern)
            prune_random = False
            try:
                if args.RANDOM_PRUNE:  #random prune
                    prune_random = True
                    self.self_attn_mask = torch.from_numpy(
                        random_mask(cpu_np_pattern, args.TAU))
                    if args.CUDA:
                        self.self_attn_mask = self.self_attn_mask.cuda()
            except AttributeError:
                # args may not define RANDOM_PRUNE / CUDA
                pass
            if not prune_random:
                cpu_np_pattern = cpu_np_pattern[:, :, 0:d1_bounds, 0:d2_bounds]
                target_percentile = args.TAU * 100
                threshold = np.percentile(cpu_np_pattern,
                                          target_percentile,
                                          interpolation='nearest')
                self.self_attn_mask = (self_attn_pattern <= threshold)
        else:
            self.self_attn_mask = None

        if encoder_attn_pattern is not None and args.PRUNE_ENC_DEC_ATTN:
            cpu_np_pattern = encoder_attn_pattern.cpu().numpy()
            d1_bounds, d2_bounds = find_bounds(cpu_np_pattern)
            prune_random = False
            try:
                if args.RANDOM_PRUNE:  #random prune
                    prune_random = True
                    self.encoder_attn_mask = torch.from_numpy(
                        random_mask(cpu_np_pattern, args.TAU))
                    if args.CUDA:
                        self.encoder_attn_mask = self.encoder_attn_mask.cuda()
            except AttributeError:
                # args may not define RANDOM_PRUNE / CUDA
                pass
            if not prune_random:
                cpu_np_pattern = cpu_np_pattern[:, :, 0:d1_bounds, 0:d2_bounds]
                target_percentile = args.TAU * 100
                threshold = np.percentile(cpu_np_pattern,
                                          target_percentile,
                                          interpolation='nearest')
                self.encoder_attn_mask = (encoder_attn_pattern <= threshold)
        else:
            self.encoder_attn_mask = None

        self.embed_dim = args.decoder_embed_dim
        self.cross_self_attention = getattr(args, "cross_self_attention",
                                            False)
        self.self_attn = MultiheadAttention(
            embed_dim=self.embed_dim,
            num_heads=args.decoder_attention_heads,
            dropout=args.attention_dropout,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
            self_attention=not self.cross_self_attention,
            args=args,
        )
        self.dropout = args.dropout
        self.activation_fn = utils.get_activation_fn(
            activation=getattr(args, "activation_fn", "relu"))
        self.activation_dropout = getattr(args, "activation_dropout", 0)
        if self.activation_dropout == 0:
            # for backwards compatibility with models that use args.relu_dropout
            self.activation_dropout = getattr(args, "relu_dropout", 0)
        self.normalize_before = args.decoder_normalize_before

        # use LayerNorm rather than FusedLayerNorm for exporting.
        # char_inputs can be used to determine this.
        # TODO  remove this once we update apex with the fix
        export = getattr(args, "char_inputs", False)
        self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export)

        if no_encoder_attn:
            self.encoder_attn = None
            self.encoder_attn_layer_norm = None
        else:
            self.encoder_attn = MultiheadAttention(
                self.embed_dim,
                args.decoder_attention_heads,
                kdim=getattr(args, "encoder_embed_dim", None),
                vdim=getattr(args, "encoder_embed_dim", None),
                dropout=args.attention_dropout,
                encoder_decoder_attention=True,
                args=args,
            )
            self.encoder_attn_layer_norm = LayerNorm(self.embed_dim,
                                                     export=export)

        self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim)
        self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)

        self.final_layer_norm = LayerNorm(self.embed_dim, export=export)
        self.need_attn = True

        self.onnx_trace = False

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    def forward(
        self,
        x,
        encoder_out: Optional[torch.Tensor] = None,
        encoder_padding_mask: Optional[torch.Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str,
                                                   Optional[Tensor]]]] = None,
        prev_self_attn_state: Optional[List[torch.Tensor]] = None,
        prev_attn_state: Optional[List[torch.Tensor]] = None,
        self_attn_mask: Optional[torch.Tensor] = None,
        self_attn_padding_mask: Optional[torch.Tensor] = None,
        need_attn: bool = False,
        need_head_weights: bool = False,
    ):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_padding_mask (ByteTensor, optional): binary
                ByteTensor of shape `(batch, src_len)` where padding
                elements are indicated by ``1``.
            need_attn (bool, optional): return attention weights
            need_head_weights (bool, optional): return attention weights
                for each head (default: return average over heads).

        Returns:
            encoded output of shape `(seq_len, batch, embed_dim)`
        """
        if need_head_weights:
            need_attn = True

        residual = x
        if self.normalize_before:
            x = self.self_attn_layer_norm(x)
        if prev_self_attn_state is not None:
            prev_key, prev_value = prev_self_attn_state[:2]
            saved_state: Dict[str, Optional[Tensor]] = {
                "prev_key": prev_key,
                "prev_value": prev_value,
            }
            if len(prev_self_attn_state) >= 3:
                saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
            assert incremental_state is not None
            self.self_attn._set_input_buffer(incremental_state, saved_state)
        _self_attn_input_buffer = self.self_attn._get_input_buffer(
            incremental_state)
        if self.cross_self_attention and not (
                incremental_state is not None and _self_attn_input_buffer
                is not None and "prev_key" in _self_attn_input_buffer):
            if self_attn_mask is not None:
                assert encoder_out is not None
                self_attn_mask = torch.cat((x.new_zeros(
                    x.size(0), encoder_out.size(0)), self_attn_mask),
                                           dim=1)
            if self_attn_padding_mask is not None:
                if encoder_padding_mask is None:
                    assert encoder_out is not None
                    encoder_padding_mask = self_attn_padding_mask.new_zeros(
                        encoder_out.size(1), encoder_out.size(0))
                self_attn_padding_mask = torch.cat(
                    (encoder_padding_mask, self_attn_padding_mask), dim=1)
            assert encoder_out is not None
            y = torch.cat((encoder_out, x), dim=0)
        else:
            y = x

        x, attn = self.self_attn(
            query=x,
            key=y,
            value=y,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            need_weights=False,
            attn_mask=self_attn_mask,
            prune_attn_mask=self.self_attn_mask,
        )
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        if not self.normalize_before:
            x = self.self_attn_layer_norm(x)

        if self.encoder_attn is not None:
            residual = x
            if self.normalize_before:
                x = self.encoder_attn_layer_norm(x)
            if prev_attn_state is not None:
                prev_key, prev_value = prev_attn_state[:2]
                saved_state: Dict[str, Optional[Tensor]] = {
                    "prev_key": prev_key,
                    "prev_value": prev_value,
                }
                if len(prev_attn_state) >= 3:
                    saved_state["prev_key_padding_mask"] = prev_attn_state[2]
                assert incremental_state is not None
                self.encoder_attn._set_input_buffer(incremental_state,
                                                    saved_state)

            x, attn = self.encoder_attn(query=x,
                                        key=encoder_out,
                                        value=encoder_out,
                                        key_padding_mask=encoder_padding_mask,
                                        incremental_state=incremental_state,
                                        static_kv=True,
                                        need_weights=need_attn or
                                        (not self.training and self.need_attn),
                                        need_head_weights=need_head_weights,
                                        prune_attn_mask=self.encoder_attn_mask)
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = residual + x
            if not self.normalize_before:
                x = self.encoder_attn_layer_norm(x)

        residual = x
        if self.normalize_before:
            x = self.final_layer_norm(x)
        x = self.activation_fn(self.fc1(x))
        x = F.dropout(x,
                      p=float(self.activation_dropout),
                      training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        if not self.normalize_before:
            x = self.final_layer_norm(x)
        if self.onnx_trace and incremental_state is not None:
            saved_state = self.self_attn._get_input_buffer(incremental_state)
            assert saved_state is not None
            if self_attn_padding_mask is not None:
                self_attn_state = [
                    saved_state["prev_key"],
                    saved_state["prev_value"],
                    saved_state["prev_key_padding_mask"],
                ]
            else:
                self_attn_state = [
                    saved_state["prev_key"], saved_state["prev_value"]
                ]
            return x, attn, self_attn_state
        return x, attn, None

    def make_generation_fast_(self, need_attn: bool = False, **kwargs):
        self.need_attn = need_attn

    @torch.jit.export
    def reorder_incremental_state(
        self,
        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
        new_order: Tensor,
    ):
        """Scriptable reorder incremental state in transformer layers."""
        self.self_attn.reorder_incremental_state(incremental_state, new_order)

        if self.encoder_attn is not None:
            self.encoder_attn.reorder_incremental_state(
                incremental_state, new_order)
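
A small, hedged sketch of the percentile-threshold pruning done in `__init__` above: a stored attention-usage pattern is thresholded at the TAU-th percentile and everything at or below the threshold is marked for pruning. The pattern shape, and the use of NumPy's default percentile interpolation (the original requests the `nearest` rank), are illustrative assumptions.

# Hedged sketch: build a boolean prune mask from an attention pattern.
import numpy as np
import torch

TAU = 0.3                                   # fraction of attention entries to prune
pattern = torch.rand(4, 8, 8)               # illustrative attention-usage pattern

cpu_np_pattern = pattern.cpu().numpy()
threshold = float(np.percentile(cpu_np_pattern, TAU * 100))
prune_mask = pattern <= threshold           # True where attention will be masked out

print(prune_mask.float().mean().item())     # roughly TAU of the entries
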
Example #8
class HybridRNNDecoder(FairseqIncrementalDecoder):
    """
    Decoder with general structure of Chen et al., The Best of Both Worlds:
    Combining Recent Advances in Neural Machine Translation, 2018.
    https://arxiv.org/abs/1804.09849
    """
    def _init_dims(self, args, src_dict, dst_dict, embed_tokens):
        self.dropout = args.dropout

        embed_dim = embed_tokens.embedding_dim
        self.embed_tokens = embed_tokens

        self.lstm_units = args.decoder_lstm_units
        self.num_layers = args.decoder_layers
        self.initial_input_dim = embed_dim

        self.encoder_output_dim = args.encoder_embed_dim
        if args.decoder_reduced_attention_dim is None:
            self.attention_dim = self.encoder_output_dim
        else:
            self.attention_dim = args.decoder_reduced_attention_dim
        self.input_dim = self.lstm_units + self.attention_dim

        self.num_attention_heads = args.decoder_attention_heads
        self.bottleneck_dim = args.decoder_out_embed_dim

    def _init_components(self, args, src_dict, dst_dict, embed_tokens):
        self.initial_rnn_layer = nn.LSTM(input_size=self.initial_input_dim,
                                         hidden_size=self.lstm_units)

        self.proj_encoder_layer = None
        if self.attention_dim != self.encoder_output_dim:
            self.proj_encoder_layer = fairseq_transformer.Linear(
                self.encoder_output_dim, self.attention_dim)

        self.proj_layer = None
        if self.lstm_units != self.attention_dim:
            self.proj_layer = fairseq_transformer.Linear(
                self.lstm_units, self.attention_dim)

        self.attention = MultiheadAttention(
            self.attention_dim,
            self.num_attention_heads,
            dropout=args.attention_dropout,
            encoder_decoder_attention=True,
        )

        self.extra_rnn_layers = nn.ModuleList([])
        for _ in range(self.num_layers - 1):
            self.extra_rnn_layers.append(
                nn.LSTM(input_size=self.input_dim,
                        hidden_size=self.lstm_units))

        self.bottleneck_layer = None
        if self.bottleneck_dim is not None:
            self.out_embed_dim = self.bottleneck_dim
            self.bottleneck_layer = fairseq_transformer.Linear(
                self.input_dim, self.out_embed_dim)
        else:
            self.out_embed_dim = self.input_dim

        self.embed_out = nn.Parameter(
            torch.Tensor(len(dst_dict), self.out_embed_dim))
        nn.init.normal_(self.embed_out, mean=0, std=self.out_embed_dim**-0.5)

        self.vocab_reduction_module = None
        if args.vocab_reduction_params:
            self.vocab_reduction_module = vocab_reduction.VocabReduction(
                src_dict,
                dst_dict,
                args.vocab_reduction_params,
                fp16=args.fp16)

        self.onnx_trace = False

    def __init__(self, args, src_dict, dst_dict, embed_tokens):
        super().__init__(dst_dict)
        self._init_dims(args, src_dict, dst_dict, embed_tokens)
        self._init_components(args, src_dict, dst_dict, embed_tokens)

    # Enable dependency injection by subclasses
    def _unpack_encoder_out(self, encoder_out):
        """Allow taking encoder_out from different architecture which
        may have different formats.
        """
        return encoder_out

    def _init_hidden(self, encoder_out, batch_size):
        """ Initialize with latent code if available otherwise zeros."""
        return torch.zeros([1, batch_size, self.lstm_units])

    def _concat_latent_code(self, x, encoder_out):
        """Concat latent code, if available in encoder_out, which is the
        case in subclass.
        """
        return x

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    def _embed_prev_outputs(self, prev_output_tokens, incremental_state=None):
        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
        x = self.embed_tokens(prev_output_tokens)
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        return x, prev_output_tokens

    def forward(
        self,
        prev_output_tokens,
        encoder_out,
        incremental_state=None,
        possible_translation_tokens=None,
        timestep=None,
    ):
        x, prev_output_tokens = self._embed_prev_outputs(
            prev_output_tokens=prev_output_tokens,
            incremental_state=incremental_state)
        return self._forward_given_embeddings(
            embed_out=x,
            prev_output_tokens=prev_output_tokens,
            encoder_out=encoder_out,
            incremental_state=incremental_state,
            possible_translation_tokens=possible_translation_tokens,
            timestep=timestep,
        )

    def _forward_given_embeddings(
        self,
        embed_out,
        prev_output_tokens,
        encoder_out,
        incremental_state=None,
        possible_translation_tokens=None,
        timestep=None,
    ):
        x = embed_out
        (encoder_x, src_tokens,
         encoder_padding_mask) = self._unpack_encoder_out(encoder_out)
        bsz, seqlen = prev_output_tokens.size()

        state_outputs = []
        if incremental_state is not None:
            prev_states = utils.get_incremental_state(self, incremental_state,
                                                      "cached_state")
            if prev_states is None:
                prev_states = self._init_prev_states(encoder_out)

            # final 2 states of list are projected key and value
            saved_state = {
                "prev_key": prev_states[-2],
                "prev_value": prev_states[-1]
            }
            self.attention._set_input_buffer(incremental_state, saved_state)

        if incremental_state is not None:
            # first num_layers pairs of states are (prev_hidden, prev_cell)
            # for each layer
            h_prev = prev_states[0]
            c_prev = prev_states[1]
        else:
            h_prev = self._init_hidden(encoder_out, bsz).type_as(x)
            c_prev = torch.zeros([1, bsz, self.lstm_units]).type_as(x)

        x = self._concat_latent_code(x, encoder_out)
        x, (h_next, c_next) = self.initial_rnn_layer(x, (h_prev, c_prev))
        if incremental_state is not None:
            state_outputs.extend([h_next, c_next])

        x = F.dropout(x, p=self.dropout, training=self.training)

        if self.proj_encoder_layer is not None:
            encoder_x = self.proj_encoder_layer(encoder_x)

        attention_in = x
        if self.proj_layer is not None:
            attention_in = self.proj_layer(x)

        attention_out, attention_weights = self.attention(
            query=attention_in,
            key=encoder_x,
            value=encoder_x,
            key_padding_mask=encoder_padding_mask,
            incremental_state=incremental_state,
            static_kv=True,
            need_weights=(not self.training),
        )

        for i, layer in enumerate(self.extra_rnn_layers):
            residual = x
            rnn_input = torch.cat([x, attention_out], dim=2)
            rnn_input = self._concat_latent_code(rnn_input, encoder_out)

            if incremental_state is not None:
                # first num_layers pairs of states are (prev_hidden, prev_cell)
                # for each layer
                h_prev = prev_states[2 * i + 2]
                c_prev = prev_states[2 * i + 3]
            else:
                h_prev = self._init_hidden(encoder_out, bsz).type_as(x)
                c_prev = torch.zeros([1, bsz, self.lstm_units]).type_as(x)

            x, (h_next, c_next) = layer(rnn_input, (h_prev, c_prev))
            if incremental_state is not None:
                state_outputs.extend([h_next, c_next])
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = x + residual

        x = torch.cat([x, attention_out], dim=2)
        x = self._concat_latent_code(x, encoder_out)
        if self.bottleneck_layer is not None:
            x = self.bottleneck_layer(x)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        if (self.vocab_reduction_module is not None
                and possible_translation_tokens is None):
            decoder_input_tokens = prev_output_tokens.contiguous()
            possible_translation_tokens = self.vocab_reduction_module(
                src_tokens, decoder_input_tokens=decoder_input_tokens)

        output_weights = self.embed_out
        if possible_translation_tokens is not None:
            output_weights = output_weights.index_select(
                dim=0, index=possible_translation_tokens)

        logits = F.linear(x, output_weights)

        if incremental_state is not None:
            # encoder projections can be reused at each incremental step
            state_outputs.extend([prev_states[-2], prev_states[-1]])
            utils.set_incremental_state(self, incremental_state,
                                        "cached_state", state_outputs)

        return logits, attention_weights, possible_translation_tokens

    def max_positions(self):
        """Maximum output length supported by the decoder."""
        return int(1e5)  # an arbitrary large number

    def _init_prev_states(self, encoder_out):
        """
        Initial (hidden, cell) values for LSTM layers are zero.

        For encoder-decoder attention, key and value are computed once from
        the encoder outputs and stay the same throughout decoding.
        """
        (encoder_x, src_tokens,
         encoder_padding_mask) = self._unpack_encoder_out(encoder_out)
        batch_size = torch.onnx.operators.shape_as_tensor(encoder_x)[1]

        if self.proj_encoder_layer is not None:
            encoder_x = self.proj_encoder_layer(encoder_x)

        states = []
        for _ in range(self.num_layers):
            hidden = self._init_hidden(encoder_out,
                                       batch_size).type_as(encoder_x)
            cell = torch.zeros([1, batch_size,
                                self.lstm_units]).type_as(encoder_x)
            states.extend([hidden, cell])

        # (key, value) for encoder-decoder attention computed from encoder
        # output and remain the same throughout decoding
        key = self.attention.k_proj(encoder_x)
        value = self.attention.v_proj(encoder_x)

        # (key, value) kept in shape (bsz, num_heads, seq_len, head_dim)
        # to avoid repeated transpose operations
        seq_len, batch_size_int, _ = encoder_x.shape
        num_heads = self.attention.num_heads
        head_dim = self.attention.head_dim
        key = (key.view(seq_len, batch_size_int * num_heads,
                        head_dim).transpose(0,
                                            1).view(batch_size_int, num_heads,
                                                    seq_len, head_dim))
        value = (value.view(seq_len, batch_size_int * num_heads,
                            head_dim).transpose(0, 1).view(
                                batch_size_int, num_heads, seq_len, head_dim))
        states.extend([key, value])

        return states

    def reorder_incremental_state(self, incremental_state, new_order):
        # parent reorders attention model
        super().reorder_incremental_state(incremental_state, new_order)

        cached_state = utils.get_incremental_state(self, incremental_state,
                                                   "cached_state")
        if cached_state is None:
            return

        # Last 2 elements of prev_states are encoder projections
        # used for ONNX export
        for i, state in enumerate(cached_state[:-2]):
            cached_state[i] = state.index_select(1, new_order)

        utils.set_incremental_state(self, incremental_state, "cached_state",
                                    cached_state)
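A minimal standalone sketch of the `cached_state` layout this decoder builds in `_init_prev_states` and reorders above: 2 * num_layers LSTM states followed by the precomputed encoder-attention key/value. Sizes and values here are illustrative only.

import torch

num_layers, bsz, lstm_units = 2, 3, 8
num_heads, head_dim, src_len = 2, 4, 5

cached_state = []
for _ in range(num_layers):
    cached_state.extend([
        torch.zeros(1, bsz, lstm_units),   # prev_hidden for this layer
        torch.zeros(1, bsz, lstm_units),   # prev_cell for this layer
    ])
# final two entries: encoder-attention key/value kept as
# (bsz, num_heads, src_len, head_dim), computed once before decoding
cached_state.extend([
    torch.zeros(bsz, num_heads, src_len, head_dim),
    torch.zeros(bsz, num_heads, src_len, head_dim),
])

# beam search reorders hypotheses along the batch dimension; only the LSTM
# states are index_select'ed here, the attention module reorders its own buffer
new_order = torch.tensor([2, 0, 1])
for i, state in enumerate(cached_state[:-2]):
    cached_state[i] = state.index_select(1, new_order)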
class TransformerDecoderLayer(nn.Module):
    """Decoder layer block.

    In the original paper each operation (multi-head attention, encoder
    attention or FFN) is postprocessed with: `dropout -> add residual ->
    layernorm`. In the tensor2tensor code they suggest that learning is more
    robust when preprocessing each layer with layernorm and postprocessing with:
    `dropout -> add residual`. We default to the approach in the paper, but the
    tensor2tensor approach can be enabled by setting
    *args.decoder_normalize_before* to ``True``.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).
    """

    def __init__(self, args, index, no_encoder_attn=False,
                 add_bias_kv=False, add_zero_attn=False):
        super().__init__()
        self.embed_dim = args.decoder_embed_dim
        kernel_size = args.decoder_kernel_size_list[index]

        self.dropout = args.dropout
        self.activation_fn = utils.get_activation_fn(
            activation=getattr(args, 'activation_fn', 'relu')
        )
        self.activation_dropout = getattr(args, 'activation_dropout', 0)
        if self.activation_dropout == 0:
            # for backwards compatibility with models that use args.relu_dropout
            self.activation_dropout = getattr(args, 'relu_dropout', 0)
        self.normalize_before = args.decoder_normalize_before

        if args.decoder_branch_type is None:
            self.self_attn = MultiheadAttention(
                embed_dim=self.embed_dim,
                num_heads=args.decoder_attention_heads,
                dropout=args.attention_dropout,
                add_bias_kv=add_bias_kv,
                add_zero_attn=add_zero_attn,
                self_attention=True,
            )
        else:
            layers = []
            embed_dims = []
            heads = []
            num_types = len(args.decoder_branch_type)
            for layer_type in args.decoder_branch_type:
                embed_dims.append(int(layer_type.split(':')[2]))
                heads.append(int(layer_type.split(':')[3]))
                layers.append(self.get_layer(
                    args, index, embed_dims[-1], heads[-1], layer_type,
                    add_bias_kv, add_zero_attn))
            assert sum(embed_dims) == self.embed_dim, \
                (sum(embed_dims), self.embed_dim)

            self.self_attn = MultiBranch(layers, embed_dims)

        # use LayerNorm rather than FusedLayerNorm for exporting.
        # char_inputs can be used to determine this.
        # TODO: remove this once we update apex with the fix
        export = getattr(args, 'char_inputs', False)
        self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export)

        if no_encoder_attn:
            self.encoder_attn = None
            self.encoder_attn_layer_norm = None
        else:
            self.encoder_attn = MultiheadAttention(
                self.embed_dim,
                args.decoder_attention_heads,
                kdim=getattr(args, 'encoder_embed_dim', None),
                vdim=getattr(args, 'encoder_embed_dim', None),
                dropout=args.attention_dropout,
                encoder_decoder_attention=True,
            )

            self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=export)

        self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim, init=args.ffn_init)
        self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim, init=args.ffn_init)
        self.final_layer_norm = LayerNorm(self.embed_dim, export=export)

        self.need_attn = True

        self.onnx_trace = False

    def get_layer(self, args, index, out_dim, num_heads, layer_type, add_bias_kv, add_zero_attn):
        kernel_size = layer_type.split(':')[1]
        if kernel_size == 'default':
            kernel_size = args.decoder_kernel_size_list[index]
        else:
            kernel_size = int(kernel_size)
        layer_type = layer_type.split(':')[0]
        if layer_type == 'lightweight':
            layer = LightweightConv(
                out_dim, kernel_size, padding_l=kernel_size-1,
                weight_softmax=args.weight_softmax, num_heads=num_heads,
                weight_dropout=args.weight_dropout, with_linear=args.conv_linear,
            )
        elif layer_type == 'dynamic':
            layer = DynamicConv(
                out_dim, kernel_size, padding_l=kernel_size-1,
                weight_softmax=args.weight_softmax, num_heads=num_heads,
                weight_dropout=args.weight_dropout, with_linear=args.conv_linear,
                glu=args.decoder_glu,
            )
        elif layer_type == 'attn':
            layer = MultiheadAttention(
                embed_dim=out_dim,
                num_heads=num_heads,
                dropout=args.attention_dropout,
                add_bias_kv=add_bias_kv,
                add_zero_attn=add_zero_attn,
                self_attention=True,
            )
        else:
            raise NotImplementedError
        return layer

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    def forward(
        self,
        x,
        encoder_out=None,
        encoder_padding_mask=None,
        incremental_state=None,
        prev_self_attn_state=None,
        prev_attn_state=None,
        self_attn_mask=None,
        self_attn_padding_mask=None,
    ):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_padding_mask (ByteTensor): binary ByteTensor of shape
                `(batch, src_len)` where padding elements are indicated by ``1``.

        Returns:
            encoded output of shape `(seq_len, batch, embed_dim)`
        """
        residual = x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True)
        if prev_self_attn_state is not None:
            if incremental_state is None:
                incremental_state = {}
            prev_key, prev_value = prev_self_attn_state
            saved_state = {"prev_key": prev_key, "prev_value": prev_value}
            self.self_attn._set_input_buffer(incremental_state, saved_state)
        x, attn = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            need_weights=False,
            attn_mask=self_attn_mask,
        )
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True)

        if self.encoder_attn is not None:
            residual = x
            x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, before=True)
            if prev_attn_state is not None:
                if incremental_state is None:
                    incremental_state = {}
                prev_key, prev_value = prev_attn_state
                saved_state = {"prev_key": prev_key, "prev_value": prev_value}
                self.encoder_attn._set_input_buffer(incremental_state, saved_state)
            x, attn = self.encoder_attn(
                query=x,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                need_weights=(not self.training and self.need_attn),
            )
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = residual + x
            x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, after=True)

        residual = x
        x = self.maybe_layer_norm(self.final_layer_norm, x, before=True)
        x = self.activation_fn(self.fc1(x))
        x = F.dropout(x, p=self.activation_dropout, training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)
        if self.onnx_trace and incremental_state is not None:
            saved_state = self.self_attn._get_input_buffer(incremental_state)
            self_attn_state = saved_state["prev_key"], saved_state["prev_value"]
            return x, attn, self_attn_state
        return x, attn

    def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
        assert before ^ after
        if after ^ self.normalize_before:
            return layer_norm(x)
        else:
            return x

    def make_generation_fast_(self, need_attn=False, **kwargs):
        self.need_attn = need_attn
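A small hedged sketch of the branch specification format that `get_layer` above expects: each entry of *args.decoder_branch_type* is `type:kernel_size:embed_dim:num_heads`, where the kernel size may be the string `default`. The parser and example specs below are illustrative only.

def parse_branch_spec(spec, default_kernel_size):
    layer_type, kernel, embed_dim, num_heads = spec.split(':')
    kernel_size = default_kernel_size if kernel == 'default' else int(kernel)
    return layer_type, kernel_size, int(embed_dim), int(num_heads)

# the per-branch embed_dims must sum to the layer's embed_dim,
# mirroring the assert in __init__ above
specs = ['attn:default:256:4', 'dynamic:7:256:4']
parsed = [parse_branch_spec(s, default_kernel_size=3) for s in specs]
assert sum(p[2] for p in parsed) == 512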
class TransformerDecoderLayer(nn.Module):
    """Decoder layer block.

    In the original paper each operation (multi-head attention, encoder
    attention or FFN) is postprocessed with: `dropout -> add residual ->
    layernorm`. In the tensor2tensor code they suggest that learning is more
    robust when preprocessing each layer with layernorm and postprocessing with:
    `dropout -> add residual`. We default to the approach in the paper, but the
    tensor2tensor approach can be enabled by setting
    *args.decoder_normalize_before* to ``True``.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).
    """
    def __init__(self,
                 args,
                 no_encoder_attn=False,
                 add_bias_kv=False,
                 add_zero_attn=False):
        super().__init__()
        self.embed_dim = args.decoder_embed_dim
        self.cross_self_attention = getattr(args, "cross_self_attention",
                                            False)
        self.self_attn = MultiheadAttention(
            embed_dim=self.embed_dim,
            num_heads=args.decoder_attention_heads,
            dropout=args.attention_dropout,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
            self_attention=not self.cross_self_attention,
        )
        self.dropout = args.dropout
        self.activation_fn = utils.get_activation_fn(
            activation=getattr(args, "activation_fn", "relu"))
        self.activation_dropout = getattr(args, "activation_dropout", 0)
        if self.activation_dropout == 0:
            # for backwards compatibility with models that use args.relu_dropout
            self.activation_dropout = getattr(args, "relu_dropout", 0)
        self.normalize_before = args.decoder_normalize_before

        # use LayerNorm rather than FusedLayerNorm for exporting.
        # char_inputs can be used to determine this.
        # TODO: remove this once we update apex with the fix
        export = getattr(args, "char_inputs", False)
        self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export)

        if no_encoder_attn:
            self.encoder_attn = None
            self.encoder_attn_layer_norm = None
        else:
            self.encoder_attn = MultiheadAttention(
                self.embed_dim,
                args.decoder_attention_heads,
                kdim=getattr(args, "encoder_embed_dim", None),
                vdim=getattr(args, "encoder_embed_dim", None),
                dropout=args.attention_dropout,
                encoder_decoder_attention=True,
            )
            self.encoder_attn_layer_norm = LayerNorm(self.embed_dim,
                                                     export=export)

        self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim)
        self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)

        self.final_layer_norm = LayerNorm(self.embed_dim, export=export)
        self.need_attn = True

        self.onnx_trace = False

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    def forward(
        self,
        x,
        encoder_out: Optional[torch.Tensor] = None,
        encoder_padding_mask: Optional[torch.Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str,
                                                   Optional[Tensor]]]] = None,
        prev_self_attn_state: Optional[List[torch.Tensor]] = None,
        prev_attn_state: Optional[List[torch.Tensor]] = None,
        self_attn_mask: Optional[torch.Tensor] = None,
        self_attn_padding_mask: Optional[torch.Tensor] = None,
        need_attn: bool = False,
        need_head_weights: bool = False,
        encoder_out2=None,
        balance_weight=None,
        encoder_padding_mask2=None,
    ):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_padding_mask (ByteTensor, optional): binary
                ByteTensor of shape `(batch, src_len)` where padding
                elements are indicated by ``1``.
            need_attn (bool, optional): return attention weights
            need_head_weights (bool, optional): return attention weights
                for each head (default: return average over heads).
            encoder_out2 (Tensor, optional): output of a second encoder,
                attended to together with *encoder_out*.
            balance_weight (optional): when given, the two encoder outputs are
                combined inside the attention module using this weight;
                otherwise they are concatenated along the source dimension.
            encoder_padding_mask2 (ByteTensor, optional): padding mask for
                *encoder_out2*.

        Returns:
            encoded output of shape `(seq_len, batch, embed_dim)`
        """

        attn2 = None

        if need_head_weights:
            need_attn = True

        residual = x
        if self.normalize_before:
            x = self.self_attn_layer_norm(x)
        if prev_self_attn_state is not None:
            print("prev_self_attn_state not None")
            prev_key, prev_value = prev_self_attn_state[:2]
            saved_state: Dict[str, Optional[Tensor]] = {
                "prev_key": prev_key,
                "prev_value": prev_value,
            }
            if len(prev_self_attn_state) >= 3:
                saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
            assert incremental_state is not None
            self.self_attn._set_input_buffer(incremental_state, saved_state)

        _self_attn_input_buffer = self.self_attn._get_input_buffer(
            incremental_state)

        if self.cross_self_attention and not (
                incremental_state is not None and _self_attn_input_buffer
                is not None and "prev_key" in _self_attn_input_buffer):
            if self_attn_mask is not None:
                assert encoder_out is not None
                self_attn_mask = torch.cat((x.new_zeros(
                    x.size(0), encoder_out.size(0)), self_attn_mask),
                                           dim=1)
            if self_attn_padding_mask is not None:
                if encoder_padding_mask is None:
                    assert encoder_out is not None
                    encoder_padding_mask = self_attn_padding_mask.new_zeros(
                        encoder_out.size(1), encoder_out.size(0))
                self_attn_padding_mask = torch.cat(
                    (encoder_padding_mask, self_attn_padding_mask), dim=1)
            assert encoder_out is not None
            y = torch.cat((encoder_out, x), dim=0)
        else:
            #print("Here not cross self")
            y = x

        x, attn, _ = self.self_attn(
            query=x,
            key=y,
            value=y,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            need_weights=False,
            attn_mask=self_attn_mask,
        )
        #print("end self_attn")

        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        if not self.normalize_before:
            x = self.self_attn_layer_norm(x)

        if self.encoder_attn is not None:
            #print("Not None")
            residual = x
            if self.normalize_before:
                x = self.encoder_attn_layer_norm(x)

            if prev_attn_state is not None:
                prev_key, prev_value = prev_attn_state[:2]
                saved_state: Dict[str, Optional[Tensor]] = {
                    "prev_key": prev_key,
                    "prev_value": prev_value,
                }
                if len(prev_attn_state) >= 3:
                    saved_state["prev_key_padding_mask"] = prev_attn_state[2]
                assert incremental_state is not None
                self.encoder_attn._set_input_buffer(incremental_state,
                                                    saved_state)

            if encoder_out2 is not None:
                if balance_weight is not None:
                    #print("need_head_weights", need_head_weights)
                    #print("Incremental_state: ",incremental_state )
                    if need_head_weights:

                        x, attn, attn2 = self.encoder_attn(
                            query=x,
                            key=encoder_out,
                            value=encoder_out,
                            key2=encoder_out2,
                            value2=encoder_out2,
                            balance_weight=balance_weight,
                            key_padding_mask=encoder_padding_mask,
                            key_padding_mask2=encoder_padding_mask2,
                            incremental_state=incremental_state,
                            static_kv=True,
                            need_weights=need_attn
                            or (not self.training and self.need_attn),
                            need_head_weights=need_head_weights,
                        )
                    else:
                        x, attn, _ = self.encoder_attn(
                            query=x,
                            key=encoder_out,
                            value=encoder_out,
                            key2=encoder_out2,
                            value2=encoder_out2,
                            balance_weight=balance_weight,
                            key_padding_mask=encoder_padding_mask,
                            key_padding_mask2=encoder_padding_mask2,
                            incremental_state=incremental_state,
                            static_kv=True,
                            need_weights=need_attn
                            or (not self.training and self.need_attn),
                            need_head_weights=need_head_weights,
                        )

                else:
                    concat_out = torch.cat([encoder_out, encoder_out2], dim=0)
                    #print(".......")
                    #print(encoder_out.shape, encoder_out2.shape)
                    #print(concat_out.shape)
                    #print("1: ", encoder_padding_mask.shape)
                    #print("2: ",encoder_padding_mask2.shape)

                    encoder_padding = torch.cat(
                        [encoder_padding_mask, encoder_padding_mask2], dim=1)

                    x, attn, _ = self.encoder_attn(
                        query=x,
                        key=concat_out,
                        value=concat_out,
                        key_padding_mask=encoder_padding,
                        incremental_state=incremental_state,
                        static_kv=True,
                        need_weights=need_attn
                        or (not self.training and self.need_attn),
                        need_head_weights=need_head_weights,
                    )

            else:
                x, attn, _ = self.encoder_attn(
                    query=x,
                    key=encoder_out,
                    value=encoder_out,
                    key_padding_mask=encoder_padding_mask,
                    incremental_state=incremental_state,
                    static_kv=True,
                    need_weights=need_attn
                    or (not self.training and self.need_attn),
                    need_head_weights=need_head_weights,
                )

                #print("!!!!!!!", x.shape)

            x = F.dropout(x, p=self.dropout, training=self.training)
            x = residual + x
            if not self.normalize_before:
                x = self.encoder_attn_layer_norm(x)

        residual = x
        if self.normalize_before:
            x = self.final_layer_norm(x)
        x = self.activation_fn(self.fc1(x))
        x = F.dropout(x,
                      p=float(self.activation_dropout),
                      training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        if not self.normalize_before:
            x = self.final_layer_norm(x)

        if self.onnx_trace and incremental_state is not None:

            saved_state = self.self_attn._get_input_buffer(incremental_state)
            assert saved_state is not None
            if self_attn_padding_mask is not None:
                self_attn_state = [
                    saved_state["prev_key"],
                    saved_state["prev_value"],
                    saved_state["prev_key_padding_mask"],
                ]
            else:
                self_attn_state = [
                    saved_state["prev_key"], saved_state["prev_value"]
                ]
            return x, attn, self_attn_state

        return x, attn, attn2

    def make_generation_fast_(self, need_attn: bool = False, **kwargs):
        self.need_attn = need_attn
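A minimal sketch of the fallback path in the forward above when a second encoder output is given but no *balance_weight*: the two encoder outputs (`T x B x C`) are concatenated along the time dimension and their padding masks (`B x T`) along the source dimension, so a single encoder-decoder attention can attend over both. Shapes are illustrative.

import torch

T1, T2, B, C = 5, 7, 2, 16
encoder_out = torch.randn(T1, B, C)
encoder_out2 = torch.randn(T2, B, C)
encoder_padding_mask = torch.zeros(B, T1, dtype=torch.bool)
encoder_padding_mask2 = torch.zeros(B, T2, dtype=torch.bool)

concat_out = torch.cat([encoder_out, encoder_out2], dim=0)   # (T1+T2) x B x C
encoder_padding = torch.cat(
    [encoder_padding_mask, encoder_padding_mask2], dim=1)    # B x (T1+T2)

assert concat_out.shape == (T1 + T2, B, C)
assert encoder_padding.shape == (B, T1 + T2)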
Example #11
class MaskDecoderLayer(nn.Module):
    def __init__(self, args, no_encoder_attn=False):
        super().__init__()
        self.embed_dim = args.decoder_embed_dim
        self.self_attn = MultiheadAttention(
            self.embed_dim,
            args.decoder_attention_heads,
            dropout=args.attention_dropout,
        )
        self.dropout = args.dropout
        self.relu_dropout = args.relu_dropout
        self.normalize_before = args.decoder_normalize_before

        self.self_attn_layer_norm = LayerNorm(self.embed_dim)

        if no_encoder_attn:
            self.source_encoder_attn = None
            self.mask_encoder_attn = None
            self.encoder_attn_layer_norm = None
            self.concat_dense = None
        else:
            self.source_encoder_attn = MultiheadAttention(
                self.embed_dim,
                args.decoder_attention_heads,
                dropout=args.attention_dropout,
            )
            self.mask_encoder_attn = MultiheadAttention(
                self.embed_dim,
                args.decoder_attention_heads,
                dropout=args.attention_dropout,
            )
            self.encoder_attn_layer_norm = LayerNorm(self.embed_dim)
            self.concat_dense = Linear(2 * self.embed_dim,
                                       self.embed_dim,
                                       bias=True)

        self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim)
        self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)

        self.final_layer_norm = LayerNorm(self.embed_dim)
        self.need_attn = True

        self.onnx_trace = False

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    def forward(self,
                x,
                source_encoder_out,
                source_encoder_padding_mask,
                mask_encoder_out,
                mask_encoder_padding_mask,
                incremental_state,
                prev_self_attn_state=None,
                prev_source_attn_state=None,
                prev_mask_attn_state=None,
                self_attn_mask=None,
                self_attn_padding_mask=None):
        residual = x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True)
        if prev_self_attn_state is not None:
            if incremental_state is None:
                incremental_state = {}
            prev_key, prev_value = prev_self_attn_state
            saved_state = {"prev_key": prev_key, "prev_value": prev_value}
            self.self_attn._set_input_buffer(incremental_state, saved_state)
        x, _ = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            need_weights=False,
            attn_mask=self_attn_mask,
        )
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True)

        attn_source = None
        attn_mask = None
        if self.source_encoder_attn is not None:
            residual = x
            source_x = self.maybe_layer_norm(self.encoder_attn_layer_norm,
                                             x,
                                             before=True)
            mask_x = self.maybe_layer_norm(self.encoder_attn_layer_norm,
                                           x,
                                           before=True)

            self.set_attention_input_buffer(self.source_encoder_attn,
                                            incremental_state,
                                            prev_source_attn_state)
            self.set_attention_input_buffer(self.mask_encoder_attn,
                                            incremental_state,
                                            prev_mask_attn_state)

            source_x, attn_source = self.source_encoder_attn(
                query=source_x,
                key=source_encoder_out,
                value=source_encoder_out,
                key_padding_mask=source_encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                need_weights=(not self.training and self.need_attn),
            )

            mask_x, attn_mask = self.mask_encoder_attn(
                query=mask_x,
                key=mask_encoder_out,
                value=mask_encoder_out,
                key_padding_mask=mask_encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                need_weights=(not self.training and self.need_attn),
            )
            x = torch.cat([source_x, mask_x], dim=-1)
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = F.relu(self.concat_dense(x))
            x = residual + x
            x = self.maybe_layer_norm(self.encoder_attn_layer_norm,
                                      x,
                                      after=True)

        residual = x
        x = self.maybe_layer_norm(self.final_layer_norm, x, before=True)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, p=self.relu_dropout, training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)
        if self.onnx_trace:
            saved_state = self.self_attn._get_input_buffer(incremental_state)
            self_attn_state = saved_state["prev_key"], saved_state[
                "prev_value"]
            return x, attn_source, attn_mask, self_attn_state
        return x, attn_source, attn_mask

    def set_attention_input_buffer(self, attention_layer, incremental_state,
                                   previous_attn_state):
        if previous_attn_state is not None:
            if incremental_state is None:
                incremental_state = {}
            prev_key, prev_value = previous_attn_state
            saved_state = {"prev_key": prev_key, "prev_value": prev_value}
            attention_layer._set_input_buffer(incremental_state, saved_state)

    def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
        assert before ^ after
        if after ^ self.normalize_before:
            return layer_norm(x)
        else:
            return x

    def make_generation_fast_(self, need_attn=False, **kwargs):
        self.need_attn = need_attn
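A minimal sketch of the fusion step in `MaskDecoderLayer.forward` above: the outputs of the source and mask encoder attentions (`T x B x C` each) are concatenated on the feature dimension and projected back to `C` by `concat_dense` followed by a ReLU. Shapes are illustrative.

import torch
import torch.nn as nn
import torch.nn.functional as F

T, B, C = 4, 2, 8
source_x = torch.randn(T, B, C)
mask_x = torch.randn(T, B, C)
concat_dense = nn.Linear(2 * C, C, bias=True)

x = torch.cat([source_x, mask_x], dim=-1)   # T x B x 2C
x = F.relu(concat_dense(x))                 # back to T x B x C
assert x.shape == (T, B, C)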
Example #12
class TransformerDecoderLayer(nn.Module):
    def __init__(self,
                 args,
                 no_encoder_attn=False,
                 add_bias_kv=False,
                 add_zero_attn=False):
        super().__init__()
        self.embed_dim = args.decoder_embed_dim
        self.cross_self_attention = getattr(args, 'cross_self_attention',
                                            False)
        if args.div or args.entmax:
            self.self_attn = ConstrainedMultiheadAttention(
                embed_dim=self.embed_dim,
                num_heads=args.decoder_attention_heads,
                args=args,
                dropout=args.attention_dropout,
                add_bias_kv=add_bias_kv,
                add_zero_attn=add_zero_attn,
                self_attention=not self.cross_self_attention,
                cur_attn_type='ds')
        else:
            self.self_attn = MultiheadAttention(
                embed_dim=self.embed_dim,
                num_heads=args.decoder_attention_heads,
                dropout=args.attention_dropout,
                add_bias_kv=add_bias_kv,
                add_zero_attn=add_zero_attn,
                self_attention=not self.cross_self_attention,
            )
        self.dropout = args.dropout
        self.activation_fn = utils.get_activation_fn(
            activation=getattr(args, 'activation_fn', 'relu'))
        self.activation_dropout = getattr(args, 'activation_dropout', 0)
        if self.activation_dropout == 0:
            # for backwards compatibility with models that use args.relu_dropout
            self.activation_dropout = getattr(args, 'relu_dropout', 0)
        self.normalize_before = args.decoder_normalize_before

        export = getattr(args, 'char_inputs', False)
        self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export)

        if no_encoder_attn:
            self.encoder_attn = None
            self.encoder_attn_layer_norm = None
        else:
            if args.div or args.entmax:
                self.encoder_attn = ConstrainedMultiheadAttention(
                    embed_dim=self.embed_dim,
                    num_heads=args.decoder_attention_heads,
                    args=args,
                    kdim=getattr(args, 'encoder_embed_dim', None),
                    vdim=getattr(args, 'encoder_embed_dim', None),
                    dropout=args.attention_dropout,
                    encoder_decoder_attention=True,
                    cur_attn_type='ds',
                )
            else:
                self.encoder_attn = MultiheadAttention(
                    self.embed_dim,
                    args.decoder_attention_heads,
                    kdim=getattr(args, 'encoder_embed_dim', None),
                    vdim=getattr(args, 'encoder_embed_dim', None),
                    dropout=args.attention_dropout,
                    encoder_decoder_attention=True,
                )
            self.encoder_attn_layer_norm = LayerNorm(self.embed_dim,
                                                     export=export)

        self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim)
        self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)

        self.final_layer_norm = LayerNorm(self.embed_dim, export=export)
        self.need_attn = True

        self.onnx_trace = False

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    def forward(
        self,
        x,
        encoder_out=None,
        encoder_padding_mask=None,
        incremental_state=None,
        prev_self_attn_state=None,
        prev_attn_state=None,
        self_attn_mask=None,
        self_attn_padding_mask=None,
        need_attn=False,
        need_head_weights=False,
    ):

        if need_head_weights:
            need_attn = True

        residual = x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True)
        if prev_self_attn_state is not None:
            if incremental_state is None:
                incremental_state = {}
            prev_key, prev_value = prev_self_attn_state[:2]
            saved_state = {"prev_key": prev_key, "prev_value": prev_value}
            if len(prev_self_attn_state) >= 3:
                saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
            self.self_attn._set_input_buffer(incremental_state, saved_state)

        if self.cross_self_attention and not (
                incremental_state is not None and "prev_key"
                in self.self_attn._get_input_buffer(incremental_state)):
            if self_attn_mask is not None:
                self_attn_mask = torch.cat((x.new(
                    x.size(0), encoder_out.size(0)).zero_(), self_attn_mask),
                                           dim=1)
            if self_attn_padding_mask is not None:
                if encoder_padding_mask is None:
                    encoder_padding_mask = self_attn_padding_mask.new(
                        encoder_out.size(1), encoder_out.size(0)).zero_()
                self_attn_padding_mask = torch.cat(
                    (encoder_padding_mask, self_attn_padding_mask), dim=1)
            y = torch.cat((encoder_out, x), dim=0)
        else:
            y = x

        x, attn = self.self_attn(
            query=x,
            key=y,
            value=y,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            need_weights=False,
            attn_mask=self_attn_mask,
        )
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True)

        if self.encoder_attn is not None:
            residual = x
            x = self.maybe_layer_norm(self.encoder_attn_layer_norm,
                                      x,
                                      before=True)
            if prev_attn_state is not None:
                if incremental_state is None:
                    incremental_state = {}
                prev_key, prev_value = prev_attn_state[:2]
                saved_state = {"prev_key": prev_key, "prev_value": prev_value}
                if len(prev_attn_state) >= 3:
                    saved_state["prev_key_padding_mask"] = prev_attn_state[2]
                self.encoder_attn._set_input_buffer(incremental_state,
                                                    saved_state)

            x, attn = self.encoder_attn(
                query=x,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                need_weights=need_attn
                or (not self.training and self.need_attn),
                need_head_weights=need_head_weights,
            )
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = residual + x
            x = self.maybe_layer_norm(self.encoder_attn_layer_norm,
                                      x,
                                      after=True)

        residual = x
        x = self.maybe_layer_norm(self.final_layer_norm, x, before=True)
        x = self.activation_fn(self.fc1(x))
        x = F.dropout(x, p=self.activation_dropout, training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)
        if self.onnx_trace and incremental_state is not None:
            saved_state = self.self_attn._get_input_buffer(incremental_state)
            if self_attn_padding_mask is not None:
                self_attn_state = saved_state["prev_key"], saved_state[
                    "prev_value"], saved_state["prev_key_padding_mask"]
            else:
                self_attn_state = saved_state["prev_key"], saved_state[
                    "prev_value"]
            return x, attn, self_attn_state
        return x, attn

    def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
        assert before ^ after
        if after ^ self.normalize_before:
            return layer_norm(x)
        else:
            return x

    def make_generation_fast_(self, need_attn=False, **kwargs):
        self.need_attn = need_attn
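A tiny self-contained sketch of the `maybe_layer_norm` logic shared by several examples above: depending on `normalize_before` and the `before`/`after` flags, the norm runs either before the sublayer (pre-norm, tensor2tensor style) or after the residual (post-norm, paper style); the condition reduces to `after ^ normalize_before`.

import torch
import torch.nn as nn

def maybe_layer_norm(layer_norm, x, normalize_before, before=False, after=False):
    assert before ^ after
    if after ^ normalize_before:
        return layer_norm(x)
    return x

ln = nn.LayerNorm(8)
x = torch.randn(3, 2, 8)
# post-norm (normalize_before=False): the norm runs only on the `after` call
assert torch.equal(maybe_layer_norm(ln, x, False, before=True), x)
# pre-norm (normalize_before=True): the norm runs only on the `before` call
assert torch.equal(maybe_layer_norm(ln, x, True, after=True), x)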
Example #13
class TransformerDecoderLayer(nn.Module):
    """Decoder layer block.

    In the original paper each operation (multi-head attention, encoder
    attention or FFN) is postprocessed with: `dropout -> add residual ->
    layernorm`. In the tensor2tensor code they suggest that learning is more
    robust when preprocessing each layer with layernorm and postprocessing with:
    `dropout -> add residual`. We default to the approach in the paper, but the
    tensor2tensor approach can be enabled by setting
    *args.decoder_normalize_before* to ``True``.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).
    """

    def __init__(
        self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False
    ):
        super().__init__()
        self.embed_dim = args.decoder_embed_dim
        embed_dim = self.embed_dim
        self.cross_self_attention = getattr(args, "cross_self_attention", False)
        self.self_attn = MultiheadAttention(
            embed_dim=self.embed_dim,
            num_heads=args.decoder_attention_heads,
            dropout=args.attention_dropout,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
            self_attention=not self.cross_self_attention,
        )
        # per-width dropout schedule: one value for each slimmable width index
        self.dropout = [0.3] * 13
        self.activation_fn = utils.get_activation_fn(
            activation=getattr(args, "activation_fn", "relu")
        )
        self.activation_dropout = getattr(args, "activation_dropout", 0)
        if self.activation_dropout == 0:
            # for backwards compatibility with models that use args.relu_dropout
            self.activation_dropout = getattr(args, "relu_dropout", 0)
        self.normalize_before = args.decoder_normalize_before

        # use LayerNorm rather than FusedLayerNorm for exporting.
        # char_inputs can be used to determine this.
        # TODO: remove this once we update apex with the fix
        export = getattr(args, "char_inputs", False)
        # one LayerNorm per slimmable width: k/16 of embed_dim for k = 4..16
        self.self_attn_layer_norm = SlimmableLayernorm(
            [int(embed_dim * k / 16) for k in range(4, 17)])

        if no_encoder_attn:
            self.encoder_attn = None
            self.encoder_attn_layer_norm = None
        else:
            self.encoder_attn = MultiheadAttention(
                self.embed_dim,
                args.decoder_attention_heads,
                kdim=getattr(args, "encoder_embed_dim", None),
                vdim=getattr(args, "encoder_embed_dim", None),
                dropout=args.attention_dropout,
                encoder_decoder_attention=True,
            )
            self.encoder_attn_layer_norm = SlimmableLayernorm(
                [int(embed_dim * k / 16) for k in range(4, 17)])

        self.fc1 = SLinear(embed_dim, args.encoder_ffn_embed_dim)
        self.fc2 = SLinear(args.encoder_ffn_embed_dim, embed_dim)

        self.final_layer_norm = SlimmableLayernorm(
            [int(embed_dim * k / 16) for k in range(4, 17)])

        # candidate widths for the model dimension and the FFN dimension
        self.linear_list = [int(embed_dim * k / 16) for k in range(4, 17)]
        self.ffn_list = [int(args.encoder_ffn_embed_dim * k / 16)
                         for k in range(4, 17)]

        self.need_attn = True

        self.onnx_trace = False

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    def forward(
        self,
        x,
        idx,
        encoder_out: Optional[torch.Tensor] = None,
        encoder_padding_mask: Optional[torch.Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        prev_self_attn_state: Optional[List[torch.Tensor]] = None,
        prev_attn_state: Optional[List[torch.Tensor]] = None,
        self_attn_mask: Optional[torch.Tensor] = None,
        self_attn_padding_mask: Optional[torch.Tensor] = None,
        need_attn: bool = False,
        need_head_weights: bool = False,
    ):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            idx (tuple of int): width indices; ``idx[0]`` selects the model
                width and ``idx[1]`` the FFN width from the slimmable lists.
            encoder_padding_mask (ByteTensor, optional): binary
                ByteTensor of shape `(batch, src_len)` where padding
                elements are indicated by ``1``.
            need_attn (bool, optional): return attention weights
            need_head_weights (bool, optional): return attention weights
                for each head (default: return average over heads).

        Returns:
            encoded output of shape `(seq_len, batch, embed_dim)`
        """
        if need_head_weights:
            need_attn = True

        residual = x
        if self.normalize_before:
            x = self.self_attn_layer_norm(x, idx[0])
        if prev_self_attn_state is not None:
            prev_key, prev_value = prev_self_attn_state[:2]
            saved_state: Dict[str, Optional[Tensor]] = {
                "prev_key": prev_key,
                "prev_value": prev_value,
            }
            if len(prev_self_attn_state) >= 3:
                saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
            assert incremental_state is not None
            self.self_attn._set_input_buffer(incremental_state, saved_state)
        _self_attn_input_buffer = self.self_attn._get_input_buffer(incremental_state)
        if self.cross_self_attention and not (
            incremental_state is not None
            and _self_attn_input_buffer is not None
            and "prev_key" in _self_attn_input_buffer
        ):
            if self_attn_mask is not None:
                assert encoder_out is not None
                self_attn_mask = torch.cat(
                    (x.new_zeros(x.size(0), encoder_out.size(0)), self_attn_mask), dim=1
                )
            if self_attn_padding_mask is not None:
                if encoder_padding_mask is None:
                    assert encoder_out is not None
                    encoder_padding_mask = self_attn_padding_mask.new_zeros(
                        encoder_out.size(1), encoder_out.size(0)
                    )
                self_attn_padding_mask = torch.cat(
                    (encoder_padding_mask, self_attn_padding_mask), dim=1
                )
            assert encoder_out is not None
            y = torch.cat((encoder_out, x), dim=0)
        else:
            y = x

        x, attn = self.self_attn(
            index=idx,
            query=x,
            key=y,
            value=y,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            need_weights=False,
            attn_mask=self_attn_mask,
        )
        x = F.dropout(x, p=self.dropout[idx[0]], training=self.training)
        x = residual + x
        if not self.normalize_before:
            x = self.self_attn_layer_norm(x, idx[0])

        if self.encoder_attn is not None:
            residual = x
            if self.normalize_before:
                x = self.encoder_attn_layer_norm(x, idx[0])
            if prev_attn_state is not None:
                prev_key, prev_value = prev_attn_state[:2]
                saved_state: Dict[str, Optional[Tensor]] = {
                    "prev_key": prev_key,
                    "prev_value": prev_value,
                }
                if len(prev_attn_state) >= 3:
                    saved_state["prev_key_padding_mask"] = prev_attn_state[2]
                assert incremental_state is not None
                self.encoder_attn._set_input_buffer(incremental_state, saved_state)

            x, attn = self.encoder_attn(
                index=idx,
                query=x,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                need_weights=need_attn or (not self.training and self.need_attn),
                need_head_weights=need_head_weights,
            )
            x = F.dropout(x, p=self.dropout[idx[0]], training=self.training)
            x = residual + x
            if not self.normalize_before:
                x = self.encoder_attn_layer_norm(x, idx[0])

        residual = x
        if self.normalize_before:
            x = self.final_layer_norm(x, idx[0])
        x = self.activation_fn(self.fc1(x, self.linear_list[idx[0]], self.ffn_list[idx[1]]))
        x = F.dropout(x, p=self.dropout[idx[1]], training=self.training)
        x = self.fc2(x, self.ffn_list[idx[1]], self.linear_list[idx[0]])
        x = F.dropout(x, p=self.dropout[idx[0]], training=self.training)
        x = residual + x
        if not self.normalize_before:
            x = self.final_layer_norm(x, idx[0])
        if self.onnx_trace and incremental_state is not None:
            saved_state = self.self_attn._get_input_buffer(incremental_state)
            assert saved_state is not None
            if self_attn_padding_mask is not None:
                self_attn_state = [
                    saved_state["prev_key"],
                    saved_state["prev_value"],
                    saved_state["prev_key_padding_mask"],
                ]
            else:
                self_attn_state = [saved_state["prev_key"], saved_state["prev_value"]]
            return x, attn, self_attn_state
        return x, attn, None

    def make_generation_fast_(self, need_attn: bool = False, **kwargs):
        self.need_attn = need_attn
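A hedged sketch of the width schedule assumed by the slimmable layer above: the model dimension and the FFN dimension are sliced to `k/16` of their full size for `k = 4..16`, and `idx` picks one width per forward pass (``idx[0]`` for the model width, ``idx[1]`` for the FFN width). Dimensions here are illustrative.

embed_dim, ffn_dim = 512, 2048

linear_list = [int(embed_dim * k / 16) for k in range(4, 17)]
ffn_list = [int(ffn_dim * k / 16) for k in range(4, 17)]

idx = (12, 12)   # the last index selects the full width in both lists
assert linear_list[idx[0]] == embed_dim
assert ffn_list[idx[1]] == ffn_dim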
Example #14
class TransformerDecoderLayer(nn.Module):
    """Decoder layer block.

    In the original paper each operation (multi-head attention, encoder
    attention or FFN) is postprocessed with: `dropout -> add residual ->
    layernorm`. In the tensor2tensor code they suggest that learning is more
    robust when preprocessing each layer with layernorm and postprocessing with:
    `dropout -> add residual`. We default to the approach in the paper, but the
    tensor2tensor approach can be enabled by setting
    *args.decoder_normalize_before* to ``True``.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).
    """

    def __init__(
        self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False
    ):
        super().__init__()
        self.embed_dim = args.decoder_embed_dim
        self.cross_self_attention = getattr(args, "cross_self_attention", False)
        self.self_attn = MultiheadAttention(
            embed_dim=self.embed_dim,
            num_heads=args.decoder_attention_heads,
            dropout=args.attention_dropout,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
            self_attention=not self.cross_self_attention,
        )
        self.dropout = args.dropout
        self.activation_fn = utils.get_activation_fn(
            activation=getattr(args, "activation_fn", "relu")
        )
        self.activation_dropout = getattr(args, "activation_dropout", 0)
        if self.activation_dropout == 0:
            # for backwards compatibility with models that use args.relu_dropout
            self.activation_dropout = getattr(args, "relu_dropout", 0)
        self.normalize_before = args.decoder_normalize_before

        # use layerNorm rather than FusedLayerNorm for exporting.
        # char_inputs can be used to determine this.
        # TODO  remove this once we update apex with the fix
        export = getattr(args, "char_inputs", False)
        self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export)

        if no_encoder_attn:
            self.encoder_attn = None
            self.encoder_attn_layer_norm = None
        else:
            self.encoder_attn = MultiheadAttention(
                self.embed_dim,
                args.decoder_attention_heads,
                kdim=getattr(args, "encoder_embed_dim", None),
                vdim=getattr(args, "encoder_embed_dim", None),
                dropout=args.attention_dropout,
                encoder_decoder_attention=True,
            )
            self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=export)

        self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim)
        self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)

        self.final_layer_norm = LayerNorm(self.embed_dim, export=export)
        self.need_attn = True

        self.onnx_trace = False

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    def forward(
        self,
        x,
        encoder_out: Optional[torch.Tensor] = None,
        encoder_padding_mask: Optional[torch.Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        prev_self_attn_state: Optional[List[torch.Tensor]] = None,
        prev_attn_state: Optional[List[torch.Tensor]] = None,
        self_attn_mask: Optional[torch.Tensor] = None,
        self_attn_padding_mask: Optional[torch.Tensor] = None,
        need_attn: bool = False,
        need_head_weights: bool = False,
    ):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_padding_mask (ByteTensor, optional): binary
                ByteTensor of shape `(batch, src_len)` where padding
                elements are indicated by ``1``.
            need_attn (bool, optional): return attention weights
            need_head_weights (bool, optional): return attention weights
                for each head (default: return average over heads).

        Returns:
            encoded output of shape `(seq_len, batch, embed_dim)`
        """
        if need_head_weights:
            need_attn = True

        residual = x
        if self.normalize_before:
            x = self.self_attn_layer_norm(x)
        if prev_self_attn_state is not None:
            prev_key, prev_value = prev_self_attn_state[:2]
            saved_state: Dict[str, Optional[Tensor]] = {
                "prev_key": prev_key,
                "prev_value": prev_value,
            }
            if len(prev_self_attn_state) >= 3:
                saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
            assert incremental_state is not None
            self.self_attn._set_input_buffer(incremental_state, saved_state)
        _self_attn_input_buffer = self.self_attn._get_input_buffer(incremental_state)
        if self.cross_self_attention and not (
            incremental_state is not None
            and _self_attn_input_buffer is not None
            and "prev_key" in _self_attn_input_buffer
        ):
            if self_attn_mask is not None:
                assert encoder_out is not None
                self_attn_mask = torch.cat(
                    (x.new_zeros(x.size(0), encoder_out.size(0)), self_attn_mask), dim=1
                )
            if self_attn_padding_mask is not None:
                if encoder_padding_mask is None:
                    assert encoder_out is not None
                    encoder_padding_mask = self_attn_padding_mask.new_zeros(
                        encoder_out.size(1), encoder_out.size(0)
                    )
                self_attn_padding_mask = torch.cat(
                    (encoder_padding_mask, self_attn_padding_mask), dim=1
                )
            assert encoder_out is not None
            y = torch.cat((encoder_out, x), dim=0)
        else:
            y = x

        x, attn = self.self_attn(
            query=x,
            key=y,
            value=y,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            need_weights=False,
            attn_mask=self_attn_mask,
        )
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        if not self.normalize_before:
            x = self.self_attn_layer_norm(x)

        if self.encoder_attn is not None:
            residual = x
            if self.normalize_before:
                x = self.encoder_attn_layer_norm(x)
            if prev_attn_state is not None:
                prev_key, prev_value = prev_attn_state[:2]
                saved_state: Dict[str, Optional[Tensor]] = {
                    "prev_key": prev_key,
                    "prev_value": prev_value,
                }
                if len(prev_attn_state) >= 3:
                    saved_state["prev_key_padding_mask"] = prev_attn_state[2]
                assert incremental_state is not None
                self.encoder_attn._set_input_buffer(incremental_state, saved_state)

            x, attn = self.encoder_attn(
                query=x,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                need_weights=need_attn or (not self.training and self.need_attn),
                need_head_weights=need_head_weights,
            )
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = residual + x
            if not self.normalize_before:
                x = self.encoder_attn_layer_norm(x)

        residual = x
        if self.normalize_before:
            x = self.final_layer_norm(x)
        x = self.activation_fn(self.fc1(x))
        x = F.dropout(x, p=float(self.activation_dropout), training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        if not self.normalize_before:
            x = self.final_layer_norm(x)
        if self.onnx_trace and incremental_state is not None:
            saved_state = self.self_attn._get_input_buffer(incremental_state)
            assert saved_state is not None
            if self_attn_padding_mask is not None:
                self_attn_state = [
                    saved_state["prev_key"],
                    saved_state["prev_value"],
                    saved_state["prev_key_padding_mask"],
                ]
            else:
                self_attn_state = [saved_state["prev_key"], saved_state["prev_value"]]
            return x, attn, self_attn_state
        return x, attn, None

    def make_generation_fast_(self, need_attn: bool = False, **kwargs):
        self.need_attn = need_attn

    def compute_macs_params(self, T=1, S=1):
        macs = 0
        n_params = 0
        macs_attn = 0

        # LayerNorm
        n_params += sum([p.numel() for p in self.self_attn_layer_norm.parameters()])
        n_params += sum([p.numel() for p in self.final_layer_norm.parameters()])

        # self attention
        self_attn_layer = self.self_attn.compute_macs_params(T=T, S=T)
        macs += self_attn_layer['macs']
        n_params += self_attn_layer['params']
        macs_attn += self_attn_layer['macs_attn']

        # Encoder-decoder attn
        if self.encoder_attn is not None:
            # self attention scaled-dot-product Attn
            enc_attn = self.encoder_attn.compute_macs_params(T=T, S=S)
            macs += enc_attn['macs']
            n_params += enc_attn['params']
            macs_attn += enc_attn['macs_attn']

        # FFN
        fc1_params = sum([p.numel() for p in self.fc1.parameters()])
        macs += (fc1_params * T)
        n_params += (fc1_params)

        fc2_params = sum([p.numel() for p in self.fc2.parameters()])
        macs += (fc2_params * T)
        n_params += fc2_params

        return {
            'name': self.__class__.__name__,
            'macs': macs,
            'params': n_params,
            'macs_attn': macs_attn
        }
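
`compute_macs_params` above approximates the FFN cost as parameter count times the target length `T`, and delegates the attention cost to each `MultiheadAttention.compute_macs_params`. A back-of-the-envelope check of the FFN term only, with hypothetical dimensions (the numbers are illustrative, not taken from the example):

# Hypothetical sizes, for illustration only.
embed_dim, ffn_dim, T = 512, 2048, 10

fc1_params = embed_dim * ffn_dim + ffn_dim    # weight + bias of fc1
fc2_params = ffn_dim * embed_dim + embed_dim  # weight + bias of fc2

# Mirrors `macs += fc1_params * T` and `macs += fc2_params * T` above.
ffn_macs = (fc1_params + fc2_params) * T
print(fc1_params, fc2_params, ffn_macs)       # 1050624 1049088 20997120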
Example #15
0
class TransformerDecoderLayer(nn.Module):
    """Decoder layer block.

    In the original paper each operation (multi-head attention, encoder
    attention or FFN) is postprocessed with: `dropout -> add residual ->
    layernorm`. In the tensor2tensor code they suggest that learning is more
    robust when preprocessing each layer with layernorm and postprocessing with:
    `dropout -> add residual`. We default to the approach in the paper, but the
    tensor2tensor approach can be enabled by setting
    *args.decoder_normalize_before* to ``True``.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).
    """

    def __init__(self, args, no_encoder_attn=False, copyNet=False):
        super().__init__()
        self.embed_dim = args.decoder_embed_dim
        self.self_attn = MultiheadAttention(
            self.embed_dim, args.encoder_attention_heads,
            dropout=args.attention_dropout,
        )
        self.dropout = args.dropout
        self.relu_dropout = args.relu_dropout
        self.normalize_before = args.decoder_normalize_before

        self.self_attn_layer_norm = LayerNorm(self.embed_dim)

        if no_encoder_attn:
            self.encoder_attn = None
            self.encoder_attn_layer_norm = None
        else:
            self.encoder_attn = MultiheadAttention(
                self.embed_dim, args.decoder_attention_heads,
                dropout=args.attention_dropout,
            )
            self.encoder_attn_layer_norm = LayerNorm(self.embed_dim)

        self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim)
        self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)

        self.final_layer_norm = LayerNorm(self.embed_dim)
        self.need_attn = True

        self.onnx_trace = False
        #
        self.copyNet = copyNet
        if self.copyNet:
            self.target_encoder_attn = MultiheadAttention(
                    self.embed_dim, args.decoder_attention_heads,
                    dropout=args.attention_dropout,
                )
            self.incorpor_weights_W = Linear(args.decoder_embed_dim, 1)
            self.incorpor_weights_U = Linear(args.decoder_embed_dim, 1)
        else:
            self.target_encoder_attn = None
            self.incorpor_weights_W = None
            self.incorpor_weights_U = None

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    def forward(self, x, encoder_out, encoder_padding_mask, incremental_state,
                prev_self_attn_state=None, prev_attn_state=None, self_attn_mask=None,
                self_attn_padding_mask=None, TM=None, TM_padding=None):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_padding_mask (ByteTensor): binary ByteTensor of shape
                `(batch, src_len)` where padding elements are indicated by ``1``.

        Returns:
            encoded output of shape `(seq_len, batch, embed_dim)`
        """
        residual = x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True)
        if prev_self_attn_state is not None:
            if incremental_state is None:
                incremental_state = {}
            prev_key, prev_value = prev_self_attn_state
            saved_state = {"prev_key": prev_key, "prev_value": prev_value}
            self.self_attn._set_input_buffer(incremental_state, saved_state)
        x, _ = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            need_weights=False,
            attn_mask=self_attn_mask,
        )
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True)

        source_attn = None
        retrieve_attn = None
        p_copy = None
        if self.encoder_attn is not None:
            residual = x
            x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, before=True)
            if prev_attn_state is not None:
                if incremental_state is None:
                    incremental_state = {}
                prev_key, prev_value = prev_attn_state
                saved_state = {"prev_key": prev_key, "prev_value": prev_value}
                self.encoder_attn._set_input_buffer(incremental_state, saved_state)
            x, source_attn = self.encoder_attn(
                query=x,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                need_weights=(not self.training and self.need_attn),
            )
            if self.target_encoder_attn is not None:
                assert TM.size(0) > 1 and TM.size(1) == x.size(1) and TM.size(2) == x.size(2), "TM: {}, x: {}".format(TM.size(), x.size())
                assert TM.size(0) == TM_padding.size(1) and TM.size(1) == TM_padding.size(0), "TM: {}, TM_padding: {}".format(TM.size(), TM_padding.size())
                target_x, retrieve_attn = self.target_encoder_attn(
                    query=x,
                    key=TM,
                    value=TM,
                    key_padding_mask=TM_padding,
                    incremental_state=incremental_state,
                    static_kv=True,
                    need_weights=True,
                )
                p_copy = torch.sigmoid(self.incorpor_weights_W(x) + self.incorpor_weights_U(target_x))
                p_copy = p_copy.transpose(0, 1) # T x B x 1 -> B x T x 1
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = residual + x
            x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, after=True)

        residual = x
        x = self.maybe_layer_norm(self.final_layer_norm, x, before=True)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, p=self.relu_dropout, training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)
        if self.onnx_trace:
            saved_state = self.self_attn._get_input_buffer(incremental_state)
            self_attn_state = saved_state["prev_key"], saved_state["prev_value"]
            return x, source_attn, self_attn_state
        return x, source_attn, retrieve_attn, p_copy

    def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
        assert before ^ after
        if after ^ self.normalize_before:
            return layer_norm(x)
        else:
            return x

    def make_generation_fast_(self, need_attn=False, **kwargs):
        self.need_attn = need_attn
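
With `copyNet=True`, the layer above also returns `retrieve_attn` over the retrieved translation memory `TM` and a per-position copy gate `p_copy = sigmoid(W x + U target_x)`. A minimal sketch of just the gate computation, with hypothetical shapes; `W` and `U` stand in for `incorpor_weights_W` and `incorpor_weights_U`:

import torch
import torch.nn as nn

T, B, C = 7, 2, 512                # hypothetical sizes: target len, batch, embed dim
x = torch.randn(T, B, C)           # decoder state after source attention
target_x = torch.randn(T, B, C)    # attention output over the retrieved TM

W = nn.Linear(C, 1)                # plays the role of incorpor_weights_W
U = nn.Linear(C, 1)                # plays the role of incorpor_weights_U

p_copy = torch.sigmoid(W(x) + U(target_x))  # (T, B, 1): one gate per target position
p_copy = p_copy.transpose(0, 1)             # -> (B, T, 1), as returned by forward()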
Example #16
class TransformerDecoderLayer(nn.Module):
    """Decoder layer block.

    In the original paper each operation (multi-head attention, encoder
    attention or FFN) is postprocessed with: `dropout -> add residual ->
    layernorm`. In the tensor2tensor code they suggest that learning is more
    robust when preprocessing each layer with layernorm and postprocessing with:
    `dropout -> add residual`. We default to the approach in the paper, but the
    tensor2tensor approach can be enabled by setting
    *args.decoder_normalize_before* to ``True``.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).
    """

    def __init__(self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False):
        super().__init__()
        self.embed_dim = args.decoder_embed_dim
        self.cross_self_attention = getattr(args, 'cross_self_attention', False)
        self.self_attn = MultiheadAttention(
            embed_dim=self.embed_dim,
            num_heads=args.decoder_attention_heads,
            dropout=args.attention_dropout,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
            self_attention=not self.cross_self_attention,
        )
        self.dropout = args.dropout
        bi_context_attn = getattr(args, 'input_form', None)
        self.bi_context_attn = (bi_context_attn == 'sep')
        self.share_key_proj = getattr(args, 'sep_attn_share_key_proj', False)

        self.activation_fn = utils.get_activation_fn(
            activation=getattr(args, 'activation_fn', 'relu')
        )
        self.activation_dropout = getattr(args, 'activation_dropout', 0)
        if self.activation_dropout == 0:
            # for backwards compatibility with models that use args.relu_dropout
            self.activation_dropout = getattr(args, 'relu_dropout', 0)
        self.normalize_before = args.decoder_normalize_before

        # use layerNorm rather than FusedLayerNorm for exporting.
        # char_inputs can be used to determine this.
        # TODO  remove this once we update apex with the fix
        export = getattr(args, 'char_inputs', False)
        self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export)

        if no_encoder_attn:
            self.encoder_attn = None
            self.encoder_attn_layer_norm = None
        else:
            if not self.bi_context_attn:
                self.encoder_attn = MultiheadAttention(
                    self.embed_dim,
                    args.decoder_attention_heads,
                    kdim=getattr(args, 'encoder_embed_dim', None),
                    vdim=getattr(args, 'encoder_embed_dim', None),
                    dropout=args.attention_dropout,
                    encoder_decoder_attention=True,
                )
            else:
                self.encoder_attn = MultiheadAttention(
                    self.embed_dim,
                    args.decoder_attention_heads,
                    kdim=getattr(args, 'encoder_embed_dim', None),
                    vdim=getattr(args, 'encoder_embed_dim', None),
                    dropout=args.attention_dropout,
                    encoder_decoder_attention=True,
                    qkv_same_dim=not self.share_key_proj,
                )
                # note: despite the option name, the weight shared here is the query projection
                self.aug_encoder_attn = MultiheadAttention(
                    self.embed_dim,
                    args.decoder_attention_heads,
                    kdim=getattr(args, 'encoder_embed_dim', None),
                    vdim=getattr(args, 'encoder_embed_dim', None),
                    dropout=args.attention_dropout,
                    encoder_decoder_attention=True,
                    shared_q_proj_weight=self.encoder_attn.q_proj_weight if self.share_key_proj else None,
                    qkv_same_dim=not self.share_key_proj,
                )
                self.context_value_weight = getattr(args, 'ctx_value_weight', 0.5)
            self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=export)

        self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim)
        self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)

        self.final_layer_norm = LayerNorm(self.embed_dim, export=export)
        self.need_attn = True

        self.onnx_trace = False

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    def forward(
        self,
        x,
        encoder_out=None,
        encoder_padding_mask=None,
        incremental_state=None,
        prev_self_attn_state=None,
        prev_attn_state=None,
        self_attn_mask=None,
        self_attn_padding_mask=None,
        need_attn=False,
        need_head_weights=False,
        bi_context=None,
        bi_context_padding_mask=None,
    ):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_padding_mask (ByteTensor, optional): binary
                ByteTensor of shape `(batch, src_len)` where padding
                elements are indicated by ``1``.
            need_attn (bool, optional): return attention weights
            need_head_weights (bool, optional): return attention weights
                for each head (default: return average over heads).

        Returns:
            encoded output of shape `(seq_len, batch, embed_dim)`
        """
        if need_head_weights:
            need_attn = True

        residual = x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True)
        if prev_self_attn_state is not None:
            if incremental_state is None:
                incremental_state = {}
            prev_key, prev_value = prev_self_attn_state[:2]
            saved_state = {"prev_key": prev_key, "prev_value": prev_value}
            if len(prev_self_attn_state) >= 3:
                saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
            self.self_attn._set_input_buffer(incremental_state, saved_state)

        if self.cross_self_attention and not (incremental_state is not None and "prev_key" in self.self_attn._get_input_buffer(incremental_state)):
            if self_attn_mask is not None:
                self_attn_mask = torch.cat((x.new(x.size(0), encoder_out.size(0)).zero_(), self_attn_mask), dim=1)
            if self_attn_padding_mask is not None:
                if encoder_padding_mask is None:
                    encoder_padding_mask = self_attn_padding_mask.new(encoder_out.size(1), encoder_out.size(0)).zero_()
                self_attn_padding_mask = torch.cat((encoder_padding_mask, self_attn_padding_mask), dim=1)
            y = torch.cat((encoder_out, x), dim=0)
        else:
            y = x

        x, attn = self.self_attn(
            query=x,
            key=y,
            value=y,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            need_weights=False,
            attn_mask=self_attn_mask,
        )
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True)

        if self.encoder_attn is not None:
            residual = x
            x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, before=True)
            if prev_attn_state is not None:
                if incremental_state is None:
                    incremental_state = {}
                prev_key, prev_value = prev_attn_state[:2]
                saved_state = {"prev_key": prev_key, "prev_value": prev_value}
                if len(prev_attn_state) >= 3:
                    saved_state["prev_key_padding_mask"] = prev_attn_state[2]
                self.encoder_attn._set_input_buffer(incremental_state, saved_state)

            x, attn = self.encoder_attn(
                query=x,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                need_weights=need_attn or (not self.training and self.need_attn),
                need_head_weights=need_head_weights,
            )

            if self.bi_context_attn:
                bx, battn = self.aug_encoder_attn(
                    query=x,
                    key=bi_context,
                    value=bi_context,
                    key_padding_mask=bi_context_padding_mask,
                    incremental_state=incremental_state,
                    static_kv=True,
                    need_weights=need_attn or (not self.training and self.need_attn),
                    need_head_weights=need_head_weights,
                )
                x = (1. - self.context_value_weight) * x + self.context_value_weight * bx
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = residual + x
            x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, after=True)

        residual = x
        x = self.maybe_layer_norm(self.final_layer_norm, x, before=True)
        x = self.activation_fn(self.fc1(x))
        x = F.dropout(x, p=self.activation_dropout, training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)
        if self.onnx_trace and incremental_state is not None:
            saved_state = self.self_attn._get_input_buffer(incremental_state)
            if self_attn_padding_mask is not None:
                self_attn_state = saved_state["prev_key"], saved_state["prev_value"], saved_state["prev_key_padding_mask"]
            else:
                self_attn_state = saved_state["prev_key"], saved_state["prev_value"]
            return x, attn, self_attn_state
        return x, attn

    def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
        assert before ^ after
        if after ^ self.normalize_before:
            return layer_norm(x)
        else:
            return x

    def make_generation_fast_(self, need_attn=False, **kwargs):
        self.need_attn = need_attn
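
When `args.input_form == 'sep'`, the variant above runs a second encoder attention (`aug_encoder_attn`) over `bi_context` and blends the two attention outputs with a fixed convex weight before the dropout/residual step. A small sketch of that blend under assumed shapes; the 0.5 mirrors the `ctx_value_weight` default read in `__init__`:

import torch

T, B, C = 5, 2, 512                 # hypothetical sizes
x = torch.randn(T, B, C)            # encoder_attn output (attention over encoder_out)
bx = torch.randn(T, B, C)           # aug_encoder_attn output (attention over bi_context)
ctx_value_weight = 0.5              # default of args.ctx_value_weight

mixed = (1.0 - ctx_value_weight) * x + ctx_value_weight * bx
assert mixed.shape == (T, B, C)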
Example #17
0
class TarcTransformerDecoderLayer(nn.Module):
    """Decoder layer block.

    In the original paper each operation (multi-head attention, encoder
    attention or FFN) is postprocessed with: `dropout -> add residual ->
    layernorm`. In the tensor2tensor code they suggest that learning is more
    robust when preprocessing each layer with layernorm and postprocessing with:
    `dropout -> add residual`. We default to the approach in the paper, but the
    tensor2tensor approach can be enabled by setting
    *args.decoder_normalize_before* to ``True``.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).
    """

    def __init__(
        self, args, no_encoder_attn=False, num_cross_attentions=0, add_bias_kv=False, add_zero_attn=False
    ):
        super().__init__()
        self.num_cross_attentions = num_cross_attentions
        self.embed_dim = args.decoder_embed_dim
        self.cross_self_attention = getattr(args, "cross_self_attention", False)
        self.self_attn = MultiheadAttention(
            embed_dim=self.embed_dim,
            num_heads=args.decoder_attention_heads,
            dropout=args.attention_dropout,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
            self_attention=not self.cross_self_attention,
        )
        self.dropout = args.dropout
        self.activation_fn = utils.get_activation_fn(
            activation=getattr(args, "activation_fn", "relu")
        )
        self.activation_dropout = getattr(args, "activation_dropout", 0)
        if self.activation_dropout == 0:
            # for backwards compatibility with models that use args.relu_dropout
            self.activation_dropout = getattr(args, "relu_dropout", 0)
        self.normalize_before = args.decoder_normalize_before

        # use layerNorm rather than FusedLayerNorm for exporting.
        # char_inputs can be used to determine this.
        # TODO  remove this once we update apex with the fix
        export = getattr(args, "char_inputs", False)
        self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export)

        if no_encoder_attn:
            self.encoder_attn = None
            self.encoder_attn_layer_norm = None
        else:
            self.encoder_attn = MultiheadAttention(
                self.embed_dim,
                args.decoder_attention_heads,
                kdim=getattr(args, "encoder_embed_dim", None),
                vdim=getattr(args, "encoder_embed_dim", None),
                dropout=args.attention_dropout,
                encoder_decoder_attention=True,
            )
            self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=export)

        # Main modification: cross-attentions that attend to the other decoders' outputs
        self.cross_attentions = nn.ModuleList()
        self.cross_attentions_norm = nn.ModuleList()
        for i in range( num_cross_attentions ):
            self.cross_attentions.append(
                MultiheadAttention(
                    self.embed_dim,
                    args.decoder_attention_heads,
                    kdim=self.embed_dim,
                    vdim=self.embed_dim,
                    dropout=args.attention_dropout,
                    encoder_decoder_attention=True,
                )
            )
            self.cross_attentions_norm.append(
                LayerNorm(self.embed_dim, export=export)
            )

        self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim)
        self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)

        self.final_layer_norm = LayerNorm(self.embed_dim, export=export)
        self.need_attn = True

        self.onnx_trace = False

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    def forward(
        self,
        x,
        encoder_out: Optional[List[torch.Tensor]] = None,
        encoder_padding_mask: Optional[torch.Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        prev_self_attn_state: Optional[List[torch.Tensor]] = None,
        prev_attn_state: Optional[List[torch.Tensor]] = None,
        prev_cross_attn_state: Optional[List[List[torch.Tensor]]] = None,
        self_attn_mask: Optional[torch.Tensor] = None,
        self_attn_padding_mask: Optional[torch.Tensor] = None,
        need_attn: bool = False,
        need_head_weights: bool = False,
    ):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_padding_mask (ByteTensor, optional): binary
                ByteTensor of shape `(batch, src_len)` where padding
                elements are indicated by ``1``.
            need_attn (bool, optional): return attention weights
            need_head_weights (bool, optional): return attention weights
                for each head (default: return average over heads).

        Returns:
            encoded output of shape `(seq_len, batch, embed_dim)`
        """
        if need_head_weights:
            need_attn = True

        assert len(self.cross_attentions)+1 == len(encoder_out)

        residual = x
        if self.normalize_before:
            x = self.self_attn_layer_norm(x)
        if prev_self_attn_state is not None:
            prev_key, prev_value = prev_self_attn_state[:2]
            saved_state: Dict[str, Optional[Tensor]] = {
                "prev_key": prev_key,
                "prev_value": prev_value,
            }
            if len(prev_self_attn_state) >= 3:
                saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
            assert incremental_state is not None
            self.self_attn._set_input_buffer(incremental_state, saved_state)
        _self_attn_input_buffer = self.self_attn._get_input_buffer(incremental_state)
        if self.cross_self_attention and not (
            incremental_state is not None
            and _self_attn_input_buffer is not None
            and "prev_key" in _self_attn_input_buffer
        ):
            if self_attn_mask is not None:
                assert encoder_out[0] is not None
                self_attn_mask = torch.cat(
                    (x.new_zeros(x.size(0), encoder_out[0].size(0)), self_attn_mask), dim=1
                )
            if self_attn_padding_mask is not None:
                if encoder_padding_mask is None:
                    assert encoder_out[0] is not None
                    encoder_padding_mask = self_attn_padding_mask.new_zeros(
                        encoder_out[0].size(1), encoder_out[0].size(0)
                    )
                self_attn_padding_mask = torch.cat(
                    (encoder_padding_mask, self_attn_padding_mask), dim=1
                )
            assert encoder_out[0] is not None
            y = torch.cat((encoder_out[0], x), dim=0)
        else:
            y = x

        x, attn = self.self_attn(
            query=x,
            key=y,
            value=y,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            need_weights=False,
            attn_mask=self_attn_mask,
        )
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        if not self.normalize_before:
            x = self.self_attn_layer_norm(x)

        cross_attn_x = x
        if self.encoder_attn is not None:
            residual = x
            if self.normalize_before:
                x = self.encoder_attn_layer_norm(x)
            if prev_attn_state is not None:
                prev_key, prev_value = prev_attn_state[:2]
                saved_state: Dict[str, Optional[Tensor]] = {
                    "prev_key": prev_key,
                    "prev_value": prev_value,
                }
                if len(prev_attn_state) >= 3:
                    saved_state["prev_key_padding_mask"] = prev_attn_state[2]
                assert incremental_state is not None
                self.encoder_attn._set_input_buffer(incremental_state, saved_state)

            x, attn = self.encoder_attn(
                query=x,
                key=encoder_out[0],
                value=encoder_out[0],
                key_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                need_weights=need_attn or (not self.training and self.need_attn),
                need_head_weights=need_head_weights,
            )
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = residual + x
            if not self.normalize_before:
                x = self.encoder_attn_layer_norm(x)

        if self.num_cross_attentions > 0:
            residual = cross_attn_x
            all_att_output = torch.zeros_like(cross_attn_x)
            if self.normalize_before:
                cross_attn_x = self.cross_attentions_norm[0](cross_attn_x)
            for i in range( len(self.cross_attentions) ):
                if prev_cross_attn_state is not None:
                    prev_key, prev_value = prev_cross_attn_state[i][:2]
                    cross_saved_state: Dict[str, Optional[Tensor]] = {
                        "prev_key": prev_key,
                        "prev_value": prev_value,
                    }
                    if len(prev_cross_attn_state[i]) >= 3:
                        cross_saved_state["prev_key_padding_mask"] = prev_cross_attn_state[i][2]
                    assert incremental_state is not None
                    self.cross_attentions[i]._set_input_buffer(incremental_state, cross_saved_state)

                att_output, attn = self.cross_attentions[i](
                    query=cross_attn_x,
                    key=encoder_out[i+1],
                    value=encoder_out[i+1],
                    key_padding_mask=None,
                    incremental_state=incremental_state,
                    static_kv=True,
                    need_weights=need_attn or (not self.training and self.need_attn),
                    need_head_weights=need_head_weights,
                )
                att_output = F.dropout(att_output, p=self.dropout, training=self.training) 
                all_att_output = att_output + all_att_output
            if self.encoder_attn is not None:
                x = x + all_att_output  # encoder_attn and cross_attentions use the same residual, so no need to add it twice
            else:
                x = residual + x + all_att_output
            if not self.normalize_before:
                x = self.cross_attentions_norm[0](x)

        residual = x
        if self.normalize_before:
            x = self.final_layer_norm(x)
        x = self.activation_fn(self.fc1(x))
        x = F.dropout(x, p=float(self.activation_dropout), training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        if not self.normalize_before:
            x = self.final_layer_norm(x)
        if self.onnx_trace and incremental_state is not None:
            saved_state = self.self_attn._get_input_buffer(incremental_state)
            assert saved_state is not None
            if self_attn_padding_mask is not None:
                self_attn_state = [
                    saved_state["prev_key"],
                    saved_state["prev_value"],
                    saved_state["prev_key_padding_mask"],
                ]
            else:
                self_attn_state = [saved_state["prev_key"], saved_state["prev_value"]]
            return x, attn, self_attn_state
        return x, attn, None

    def make_generation_fast_(self, need_attn: bool = False, **kwargs):
        self.need_attn = need_attn

    @torch.jit.export
    def reorder_incremental_state(
        self,
        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
        new_order: Tensor,
    ):
        """Scriptable reorder incremental state in transformer layers."""
        self.self_attn.reorder_incremental_state(incremental_state, new_order)

        if self.encoder_attn is not None:
            self.encoder_attn.reorder_incremental_state(incremental_state, new_order)

        if self.num_cross_attentions > 0:
            [attn.reorder_incremental_state(incremental_state, new_order) for attn in self.cross_attentions]
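
`TarcTransformerDecoderLayer` expects `encoder_out` to be a list: element 0 feeds the regular encoder attention and the cross-self-attention path, while elements 1..N feed the extra cross-attentions over the other decoders' outputs, which is what the `len(self.cross_attentions)+1 == len(encoder_out)` assertion in `forward` enforces. A hedged sketch of how such a list might be assembled (shapes and sizes are assumptions):

import torch

src_len, other_len, B, C = 9, 7, 2, 512   # hypothetical sizes
num_cross_attentions = 2

encoder_out = [torch.randn(src_len, B, C)]      # index 0: encoder states
encoder_out += [torch.randn(other_len, B, C)    # indices 1..N: other decoders' states
                for _ in range(num_cross_attentions)]

# Matches the check at the top of forward():
assert num_cross_attentions + 1 == len(encoder_out)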
Example #18
0
class AttentionDecoderLayer(nn.Module):
    def __init__(
        self,
        embed_dim,
        attention_heads,
        self_attention=True,
        add_bias_kv=False,
        add_zero_attn=False,
        attention_dropout=0.1,
        dropout=0.3,
        normalize_before=False,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.self_attn = MultiheadAttention(
            embed_dim=self.embed_dim,
            num_heads=attention_heads,
            dropout=attention_dropout,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
            self_attention=True) if self_attention else None
        self.dropout = dropout
        self.normalize_before = normalize_before
        self.encoder_attn = MultiheadAttention(
            embed_dim=self.embed_dim,
            num_heads=attention_heads,
            kdim=None,
            vdim=None,
            dropout=attention_dropout,
            encoder_decoder_attention=True) if not self_attention else None
        self.attn_layer_norm = LayerNorm(self.embed_dim, export=False)

    def forward(self,
                x,
                encoder_out=None,
                incremental_state=None,
                encoder_padding_mask=None,
                prev_self_attn_state=None,
                prev_attn_state=None,
                self_attn_mask=None,
                self_attn_padding_mask=None,
                need_attn=False,
                need_head_weights=False,
                **unused):
        # encoder_out = None
        # encoder_padding_mask = None
        # if encoder_outs is not None:
        #     if 'encoder_out' in encoder_outs.keys():
        #         encoder_out = encoder_outs['encoder_out']
        #     if 'encoder_padding_mask' in encoder_outs.keys():
        #         encoder_padding_mask = encoder_outs['encoder_padding_mask']

        residual = x
        x = self.maybe_layer_norm(x, before=True)

        if self.self_attn is not None:
            if prev_self_attn_state is not None:
                if incremental_state is None:
                    incremental_state = {}
                prev_key, prev_value = prev_self_attn_state[:2]
                saved_state = {"prev_key": prev_key, "prev_value": prev_value}
                if len(prev_self_attn_state) >= 3:
                    saved_state[
                        "prev_key_padding_mask"] = prev_self_attn_state[2]
                self.self_attn._set_input_buffer(incremental_state,
                                                 saved_state)

            x, _ = self.self_attn(query=x,
                                  key=x,
                                  value=x,
                                  key_padding_mask=self_attn_padding_mask,
                                  incremental_state=incremental_state,
                                  need_weights=False,
                                  attn_mask=self_attn_mask)

        if self.encoder_attn is not None:
            if prev_attn_state is not None:
                if incremental_state is None:
                    incremental_state = {}
                prev_key, prev_value = prev_attn_state[:2]
                saved_state = {"prev_key": prev_key, "prev_value": prev_value}
                if len(prev_attn_state) >= 3:
                    saved_state["prev_key_padding_mask"] = prev_attn_state[2]
                self.encoder_attn._set_input_buffer(incremental_state,
                                                    saved_state)

            x, _ = self.encoder_attn(query=x,
                                     key=encoder_out,
                                     value=encoder_out,
                                     key_padding_mask=encoder_padding_mask,
                                     incremental_state=incremental_state,
                                     static_kv=True,
                                     need_weights=False,
                                     need_head_weights=False)

        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(x, after=True)

        return x, None

    def maybe_layer_norm(self, x, before=False, after=False):
        assert before ^ after
        if after ^ self.normalize_before:
            return self.attn_layer_norm(x)
        else:
            return x


# import torch
# device = torch.device("cuda:0")
# x = torch.rand(4, 2, 8).to(device)
# encoder = AttentionDecoderLayer(8, 4).to(device)
# mask = torch.zeros(2, 4).bool().to(device)
# y, _ = encoder(x, incremental_state=None, self_attn_padding_mask=mask)
# print(y.size())
Example #19
0
File: bgt.py Project: jwcmu/bgt
class TransformerSentenceEmbeddingDecoderLayer(TransformerDecoderLayer):
    """Decoder layer block.

    In the original paper each operation (multi-head attention, encoder
    attention or FFN) is postprocessed with: `dropout -> add residual ->
    layernorm`. In the tensor2tensor code they suggest that learning is more
    robust when preprocessing each layer with layernorm and postprocessing with:
    `dropout -> add residual`. We default to the approach in the paper, but the
    tensor2tensor approach can be enabled by setting
    *args.decoder_normalize_before* to ``True``.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).
    """
    def __init__(self,
                 args,
                 add_bias_kv=False,
                 add_zero_attn=False,
                 do_trans=True):
        super().__init__(args)
        self.embed_dim = args.decoder_embed_dim
        self.args = args
        self.do_trans = do_trans
        self.self_attn = MultiheadAttention(
            embed_dim=self.embed_dim,
            num_heads=args.decoder_attention_heads,
            dropout=args.attention_dropout,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
            self_attention=True)
        self.dropout = args.dropout
        self.activation_fn = utils.get_activation_fn(
            activation=getattr(args, 'activation_fn', 'relu'))
        self.activation_dropout = getattr(args, 'activation_dropout', 0)
        if self.activation_dropout == 0:
            # for backwards compatibility with models that use args.relu_dropout
            self.activation_dropout = getattr(args, 'relu_dropout', 0)
        self.normalize_before = args.decoder_normalize_before

        # use layerNorm rather than FusedLayerNorm for exporting.
        # char_inputs can be used to determine this.
        # TODO  remove this once we update apex with the fix
        export = getattr(args, 'char_inputs', False)
        self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export)

        self.encoder_attn_layer_norm = LayerNorm(self.embed_dim)

        self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim)
        self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)

        self.final_layer_norm = LayerNorm(self.embed_dim, export=export)
        self.need_attn = True

        self.onnx_trace = False

        if do_trans:
            self.decoder_fc1 = Linear(self.embed_dim + self.args.latent_size,
                                      self.embed_dim)
        else:
            self.decoder_fc1 = Linear(
                self.embed_dim + self.args.latent_size * 2, self.embed_dim)

    def forward(
        self,
        x,
        sent_emb=None,
        encoder_out=None,
        encoder_padding_mask=None,
        incremental_state=None,
        prev_self_attn_state=None,
        prev_attn_state=None,
        self_attn_mask=None,
        self_attn_padding_mask=None,
    ):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_padding_mask (ByteTensor): binary ByteTensor of shape
                `(batch, src_len)` where padding elements are indicated by ``1``.

        Returns:
            encoded output of shape `(seq_len, batch, embed_dim)`
        """
        residual = x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True)
        if prev_self_attn_state is not None:
            if incremental_state is None:
                incremental_state = {}
            prev_key, prev_value = prev_self_attn_state
            saved_state = {"prev_key": prev_key, "prev_value": prev_value}
            self.self_attn._set_input_buffer(incremental_state, saved_state)
        x, attn = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            need_weights=False,
            attn_mask=self_attn_mask,
        )
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True)

        residual = x
        x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, before=True)
        size = (x.size()[0], x.size()[1], sent_emb.size()[-1])

        concat_sent_emb = torch.cat((x, sent_emb.expand(size)), dim=2)
        x = self.decoder_fc1(concat_sent_emb)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, after=True)

        residual = x
        x = self.maybe_layer_norm(self.final_layer_norm, x, before=True)
        x = self.activation_fn(self.fc1(x))
        x = F.dropout(x, p=self.activation_dropout, training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)
        if self.onnx_trace and incremental_state is not None:
            saved_state = self.self_attn._get_input_buffer(incremental_state)
            self_attn_state = saved_state["prev_key"], saved_state[
                "prev_value"]
            return x, attn, self_attn_state
        return x, attn
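
Instead of attending over encoder states, the bgt layer broadcasts a single sentence embedding across all target positions, concatenates it to the decoder state and projects back to `embed_dim` through `decoder_fc1`. A shape-only sketch of that step, with assumed sizes (including the shape of `sent_emb`, which the example never shows):

import torch
import torch.nn as nn

T, B, embed_dim, latent_size = 6, 2, 512, 128   # hypothetical sizes
x = torch.randn(T, B, embed_dim)                # decoder states after self-attention
sent_emb = torch.randn(1, B, latent_size)       # assumed: one latent vector per sentence

decoder_fc1 = nn.Linear(embed_dim + latent_size, embed_dim)  # as in the do_trans=True branch

size = (x.size(0), x.size(1), sent_emb.size(-1))
concat = torch.cat((x, sent_emb.expand(size)), dim=2)        # (T, B, embed_dim + latent_size)
out = torch.relu(decoder_fc1(concat))                        # (T, B, embed_dim)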
Example #20
0
class HybridRNNDecoder(FairseqIncrementalDecoder):
    """
    Transformer decoder consisting of *args.decoder_layers* layers. Each layer
    is a :class:`TransformerDecoderLayer`.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): decoding dictionary
        embed_tokens (torch.nn.Embedding): output embedding
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).
    """

    def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
        super().__init__(dictionary)
        self.register_buffer('version', torch.Tensor([3]))

        self.dropout = args.dropout
        self.share_input_output_embed = args.share_decoder_input_output_embed

        embed_dim = embed_tokens.embedding_dim
        self.embed_tokens = embed_tokens

        self.lstm_units = args.decoder_lstm_units
        self.num_layers = args.decoder_layers
        self.initial_input_dim = embed_dim

        self.encoder_output_dim = args.encoder_embed_dim
        if args.decoder_reduced_attention_dim is None:
            self.attention_dim = self.encoder_output_dim
        else:
            self.attention_dim = args.decoder_reduced_attention_dim
        self.input_dim = self.lstm_units + self.attention_dim

        self.num_attention_heads = args.decoder_attention_heads
        self.bottleneck_dim = args.decoder_out_embed_dim

        self.initial_rnn_layer = nn.LSTM(
            input_size=self.initial_input_dim, hidden_size=self.lstm_units
        )
        self.initial_layernorm = LayerNorm(self.lstm_units)

        self.proj_encoder_layer = None
        if self.attention_dim != self.encoder_output_dim:
            self.proj_encoder_layer = Linear(
                self.encoder_output_dim, self.attention_dim
            )

        self.proj_layer = None
        if self.lstm_units != self.attention_dim:
            self.proj_layer = Linear(
                self.lstm_units, self.attention_dim
            )

        self.attention = MultiheadAttention(
            self.attention_dim,
            self.num_attention_heads,
            dropout=args.attention_dropout,
            encoder_decoder_attention=True,
        )

        self.extra_rnn_layers = nn.ModuleList([])
        self.extra_layernorms = nn.ModuleList([])
        for _ in range(self.num_layers - 1):
            self.extra_rnn_layers.append(
                nn.LSTM(input_size=self.input_dim, hidden_size=self.lstm_units)
            )
            self.extra_layernorms.append(
                LayerNorm(self.lstm_units)
            )

        self.bottleneck_layer = None
        if self.bottleneck_dim is not None:
            self.out_embed_dim = self.bottleneck_dim
            self.bottleneck_layer = Linear(
                self.input_dim, self.out_embed_dim
            )
        else:
            self.out_embed_dim = self.input_dim

        if not self.share_input_output_embed:
            self.embed_out = nn.Parameter(torch.Tensor(len(dictionary), self.out_embed_dim))
            nn.init.normal_(self.embed_out, mean=0, std=self.out_embed_dim ** -0.5)
        else:
            assert self.bottleneck_dim == args.decoder_embed_dim, (self.bottleneck_dim, args.decoder_embed_dim)

    def _unpack_encoder_out(self, encoder_out):
        """ Allow taking encoder_out from different architecture which
        may have different formats.
        """
        # return encoder_out['encoder_out'], encoder_out['encoder_padding_mask']
        return encoder_out.encoder_out, encoder_out.encoder_padding_mask

    def _init_hidden(self, encoder_out, batch_size):
        """ Initialize with latent code if available otherwise zeros."""
        return torch.zeros([1, batch_size, self.lstm_units])

    def _concat_latent_code(self, x, encoder_out):
        """ Concat latent code, if available in encoder_out, which is the
        case in subclass.
        """
        return x

    def _embed_prev_outputs(self, prev_output_tokens, incremental_state=None):
        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
        x = self.embed_tokens(prev_output_tokens)
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        return x, prev_output_tokens

    def forward(
        self,
        prev_output_tokens,
        encoder_out,
        incremental_state=None,
        possible_translation_tokens=None,
        timestep=None,
    ):
        x, prev_output_tokens = self._embed_prev_outputs(
            prev_output_tokens=prev_output_tokens, incremental_state=incremental_state
        )
        return self._forward_given_embeddings(
            embed_out=x,
            prev_output_tokens=prev_output_tokens,
            encoder_out=encoder_out,
            incremental_state=incremental_state,
            possible_translation_tokens=possible_translation_tokens,
            timestep=timestep,
        )

    def _forward_given_embeddings(
        self,
        embed_out,
        prev_output_tokens,
        encoder_out,
        incremental_state=None,
        possible_translation_tokens=None,
        timestep=None,
    ):
        x = embed_out
        (encoder_x, encoder_padding_mask) = self._unpack_encoder_out(encoder_out)
        bsz, seqlen = prev_output_tokens.size()

        state_outputs = []
        if incremental_state is not None:
            prev_states = utils.get_incremental_state(
                self, incremental_state, "cached_state"
            )
            if prev_states is None:
                prev_states = self._init_prev_states(encoder_out)

            # final 2 states of list are projected key and value
            saved_state = {"prev_key": prev_states[-2], "prev_value": prev_states[-1]}
            self.attention._set_input_buffer(incremental_state, saved_state)

        if incremental_state is not None:
            # first num_layers pairs of states are (prev_hidden, prev_cell)
            # for each layer
            h_prev = prev_states[0]
            c_prev = prev_states[1]
        else:
            h_prev = self._init_hidden(encoder_out, bsz).type_as(x)
            c_prev = torch.zeros([1, bsz, self.lstm_units]).type_as(x)

        x = self._concat_latent_code(x, encoder_out)
        x, (h_next, c_next) = self.initial_rnn_layer(x, (h_prev, c_prev))
        x = self.initial_layernorm(x)
        if incremental_state is not None:
            state_outputs.extend([h_next, c_next])

        x = F.dropout(x, p=self.dropout, training=self.training)

        if self.proj_encoder_layer is not None:
            encoder_x = self.proj_encoder_layer(encoder_x)

        attention_in = x
        if self.proj_layer is not None:
            attention_in = self.proj_layer(x)

        attention_out, attention_weights = self.attention(
            query=attention_in,
            key=encoder_x,
            value=encoder_x,
            key_padding_mask=encoder_padding_mask,
            incremental_state=incremental_state,
            static_kv=True,
            need_weights=(not self.training),
        )

        for i, layer in enumerate(self.extra_rnn_layers):
            residual = x
            rnn_input = torch.cat([x, attention_out], dim=2)
            rnn_input = self._concat_latent_code(rnn_input, encoder_out)

            if incremental_state is not None:
                # first num_layers pairs of states are (prev_hidden, prev_cell)
                # for each layer
                h_prev = prev_states[2 * i + 2]
                c_prev = prev_states[2 * i + 3]
            else:
                h_prev = self._init_hidden(encoder_out, bsz).type_as(x)
                c_prev = torch.zeros([1, bsz, self.lstm_units]).type_as(x)

            x, (h_next, c_next) = layer(rnn_input, (h_prev, c_prev))
            if incremental_state is not None:
                state_outputs.extend([h_next, c_next])
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = x + residual
            x = self.extra_layernorms[i](x)

        x = torch.cat([x, attention_out], dim=2)
        x = self._concat_latent_code(x, encoder_out)
        if self.bottleneck_layer is not None:
            x = self.bottleneck_layer(x)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        if self.share_input_output_embed:
            logits = F.linear(x, self.embed_tokens.weight)
        else:
            logits = F.linear(x, self.embed_out)

        if incremental_state is not None:
            # encoder projections can be reused at each incremental step
            state_outputs.extend([prev_states[-2], prev_states[-1]])
            utils.set_incremental_state(
                self, incremental_state, "cached_state", state_outputs
            )

        return logits, attention_weights

    def max_positions(self):
        """Maximum output length supported by the decoder."""
        return int(1024)  # an arbitrarily large number

    def _init_prev_states(self, encoder_out):
        """
        Initial (hidden, cell) values for LSTM layers are zero.

        For encoder-decoder attention, key and value are computed once from
        the encoder outputs and stay the same throughout decoding.
        """
        (encoder_x, encoder_padding_mask) = self._unpack_encoder_out(encoder_out)
        batch_size = torch.onnx.operators.shape_as_tensor(encoder_x)[1]

        if self.proj_encoder_layer is not None:
            encoder_x = self.proj_encoder_layer(encoder_x)

        states = []
        for _ in range(self.num_layers):
            hidden = self._init_hidden(encoder_out, batch_size).type_as(encoder_x)
            cell = torch.zeros([1, batch_size, self.lstm_units]).type_as(encoder_x)
            states.extend([hidden, cell])

        # (key, value) for encoder-decoder attention computed from encoder
        # output and remain the same throughout decoding
        key = self.attention.k_proj(encoder_x)
        value = self.attention.v_proj(encoder_x)

        # (key, value) kept in shape (bsz, num_heads, seq_len, head_dim)
        # to avoid repeated transpose operations
        seq_len, batch_size_int, _ = encoder_x.shape
        num_heads = self.attention.num_heads
        head_dim = self.attention.head_dim
        key = (
            key.view(seq_len, batch_size_int * num_heads, head_dim)
            .transpose(0, 1)
            .view(batch_size_int, num_heads, seq_len, head_dim)
        )
        value = (
            value.view(seq_len, batch_size_int * num_heads, head_dim)
            .transpose(0, 1)
            .view(batch_size_int, num_heads, seq_len, head_dim)
        )
        states.extend([key, value])

        return states

    def reorder_incremental_state(self, incremental_state, new_order):
        # parent reorders attention model
        super().reorder_incremental_state(incremental_state, new_order)

        cached_state = utils.get_incremental_state(
            self, incremental_state, "cached_state"
        )
        if cached_state is None:
            return

        # Last 2 elements of prev_states are encoder projections
        # used for ONNX export
        for i, state in enumerate(cached_state[:-2]):
            cached_state[i] = state.index_select(1, new_order)
        for i in [-2, -1]:
            cached_state[i] = cached_state[i].index_select(0, new_order)

        utils.set_incremental_state(
            self, incremental_state, "cached_state", cached_state
        )
Example #21
class TransformerAANDecoderLayer(nn.Module):
    """Decoder layer block.
    In the original paper each operation (multi-head attention, encoder
    attention or FFN) is postprocessed with: `dropout -> add residual ->
    layernorm`. In the tensor2tensor code they suggest that learning is more
    robust when preprocessing each layer with layernorm and postprocessing with:
    `dropout -> add residual`. We default to the approach in the paper, but the
    tensor2tensor approach can be enabled by setting
    *args.decoder_normalize_before* to ``True``.
    Args:
        args (argparse.Namespace): parsed command-line arguments
        no_encoder_attn (bool, optional): whether to attend to encoder outputs.
            Default: ``False``
    """
    def __init__(self, args, no_encoder_attn=False):
        super().__init__()
        self.embed_dim = args.decoder_embed_dim
        self.dropout = args.dropout
        self.relu_dropout = args.relu_dropout
        self.more_dropouts = args.decoder_aan_more_dropouts
        if args.decoder_attn_window_size <= 0:
            self.avg_attn = AverageAttention(self.embed_dim,
                                             dropout=args.attention_dropout)
        else:
            self.avg_attn = AverageWindowAttention(
                self.embed_dim,
                dropout=args.attention_dropout,
                window_size=args.decoder_attn_window_size,
            )
        # self.activation = getattr(args, "decoder_ffn_activation", "relu")
        self.aan_layer_norm = LayerNorm(self.embed_dim)
        if args.no_decoder_aan_ffn:
            self.aan_ffn = None
        else:
            aan_ffn_hidden_dim = (self.embed_dim
                                  if args.decoder_aan_ffn_use_embed_dim else
                                  args.decoder_ffn_embed_dim)
            self.aan_ffn = FeedForwardNetwork(
                self.embed_dim,
                aan_ffn_hidden_dim,
                self.embed_dim,
                num_layers=2,
                dropout=args.relu_dropout,
            )

        if args.no_decoder_aan_gating:
            self.aan_gating_fc = None
        else:
            self.aan_gating_fc = Linear(self.embed_dim * 2, self.embed_dim * 2)
        self.normalize_before = args.decoder_normalize_before

        if no_encoder_attn:
            self.encoder_attn = None
            self.encoder_attn_layer_norm = None
        else:
            self.encoder_attn = MultiheadAttention(
                self.embed_dim,
                args.decoder_attention_heads,
                kdim=args.encoder_embed_dim,
                vdim=args.encoder_embed_dim,
                dropout=args.attention_dropout,
                encoder_decoder_attention=True,
            )
            self.encoder_attn_layer_norm = LayerNorm(self.embed_dim)

        self.ffn = FeedForwardNetwork(
            self.embed_dim,
            args.decoder_ffn_embed_dim,
            self.embed_dim,
            num_layers=2,
            dropout=args.relu_dropout,
        )

        self.final_layer_norm = LayerNorm(self.embed_dim)
        self.need_attn = True

        self.onnx_trace = False

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    def forward(
        self,
        x,
        encoder_out,
        encoder_padding_mask,
        incremental_state,
        prev_self_attn_state=None,
        prev_attn_state=None,
        self_attn_mask=None,
        self_attn_padding_mask=None,
    ):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_padding_mask (ByteTensor): binary ByteTensor of shape
                `(batch, src_len)` where padding elements are indicated by ``1``.
        Returns:
            encoded output of shape `(seq_len, batch, embed_dim)`
        """
        residual = x
        if "residual" in self.more_dropouts:
            residual = F.dropout(residual,
                                 p=self.dropout,
                                 training=self.training)

        if prev_self_attn_state is not None:
            if incremental_state is None:
                incremental_state = {}
            prev_key, prev_value = prev_self_attn_state
            saved_state = {"prev_key": prev_key, "prev_value": prev_value}
            self.avg_attn._set_input_buffer(incremental_state, saved_state)
        x, _ = self.avg_attn(
            value=x,
            mask_future_timesteps=True,
            incremental_state=incremental_state,
            mask_trick=self.training,
        )
        if "after_avg" in self.more_dropouts:
            x = F.dropout(x, p=self.dropout, training=self.training)

        if self.aan_layer_norm is not None:
            x = self.maybe_layer_norm(self.aan_layer_norm, x, before=True)

        if self.aan_ffn is not None:
            x = self.aan_ffn(x)
            if "after_ffn" in self.more_dropouts:
                x = F.dropout(x, p=self.dropout, training=self.training)

        if self.aan_gating_fc is not None:
            i, f = self.aan_gating_fc(torch.cat([residual, x],
                                                dim=-1)).chunk(2, dim=-1)
            x = torch.sigmoid(f) * residual + torch.sigmoid(i) * x
            if "after_gating" in self.more_dropouts:
                x = F.dropout(x, p=self.dropout, training=self.training)

        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x

        if self.aan_layer_norm is not None:
            x = self.maybe_layer_norm(self.aan_layer_norm, x, after=True)

        attn = None
        if self.encoder_attn is not None:
            residual = x
            x = self.maybe_layer_norm(self.encoder_attn_layer_norm,
                                      x,
                                      before=True)
            if prev_attn_state is not None:
                if incremental_state is None:
                    incremental_state = {}
                prev_key, prev_value = prev_attn_state
                saved_state = {"prev_key": prev_key, "prev_value": prev_value}
                self.encoder_attn._set_input_buffer(incremental_state,
                                                    saved_state)
            x, attn = self.encoder_attn(
                query=x,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                need_weights=(not self.training and self.need_attn),
            )
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = residual + x
            x = self.maybe_layer_norm(self.encoder_attn_layer_norm,
                                      x,
                                      after=True)

        residual = x
        x = self.maybe_layer_norm(self.final_layer_norm, x, before=True)
        x = self.ffn(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)

        return x, attn

    def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
        assert before ^ after
        if after ^ self.normalize_before:
            return layer_norm(x)
        else:
            return x

    def make_generation_fast_(self, need_attn=False, **kwargs):
        self.need_attn = need_attn

    def extra_repr(self):
        return "dropout={}, more_dropouts={}".format(self.dropout,
                                                     self.more_dropouts)
Example #22
class TransformerDecoderLayerPhase2(nn.Module):
    """Second phase of decoder layer block
    This layer will take the input from the ecoder and phirst pass decoder.
    papers.nips.cc/paper/6775-deliberation-networks-sequence-generation-beyond-one-pass-decoding.pdf
    """
    def __init__(
        self,
        args,
        no_encoder_decoder_attn=False,
        add_bias_kv=False,
        add_zero_attn=False,
    ):
        super().__init__()
        self.embed_dim = args.decoder_embed_dim
        self.self_attn = MultiheadAttention(
            embed_dim=self.embed_dim,
            num_heads=args.decoder_attention_heads,
            dropout=args.attention_dropout,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
            self_attention=True,
        )
        self.dropout = args.dropout
        self.activation_fn = utils.get_activation_fn(
            activation=getattr(args, "activation_fn", "relu"))
        self.activation_dropout = getattr(args, "activation_dropout", 0)
        if self.activation_dropout == 0:
            # for backwards compatibility with models that use args.relu_dropout
            self.activation_dropout = getattr(args, "relu_dropout", 0)
        self.normalize_before = args.decoder_normalize_before

        # use layerNorm rather than FusedLayerNorm for exporting.
        # char_inputs can be used to determine this.
        # TODO  remove this once we update apex with the fix
        export = getattr(args, "char_inputs", False)
        self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export)

        if no_encoder_decoder_attn:
            self.encoder_attn = None
            self.decoder_attn = None
            self.encoder_attn_layer_norm = None
            self.decoder_attn_layer_norm = None
        else:
            self.encoder_attn = MultiheadAttention(
                self.embed_dim,
                args.decoder_attention_heads,
                dropout=args.attention_dropout,
                encoder_decoder_attention=True,
            )
            self.decoder_attn = MultiheadAttention(
                self.embed_dim,
                args.decoder_attention_heads,
                dropout=args.attention_dropout,
                encoder_decoder_attention=True,
            )
            self.encoder_attn_layer_norm = LayerNorm(self.embed_dim,
                                                     export=export)
            self.decoder_attn_layer_norm = LayerNorm(self.embed_dim,
                                                     export=export)

        self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim)
        self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)

        self.final_layer_norm = LayerNorm(self.embed_dim, export=export)
        self.need_attn = True

        self.onnx_trace = False

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    def forward(
        self,
        x,
        encoder_out=None,
        encoder_padding_mask=None,
        decoder_out=None,
        incremental_state=None,
        prev_self_attn_state=None,
        prev_encoder_attn_state=None,
        prev_decoder_attn_state=None,
        self_attn_mask=None,
        self_attn_padding_mask=None,
    ):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_padding_mask (ByteTensor): binary ByteTensor of shape
                `(batch, src_len)` where padding elements are indicated by ``1``.

        Returns:
            encoded output of shape `(seq_len, batch, embed_dim)`
        """
        residual = x
        x_self_attention = self.maybe_layer_norm(self.self_attn_layer_norm,
                                                 x,
                                                 before=True)
        if prev_self_attn_state is not None:
            if incremental_state is None:
                incremental_state = {}
            prev_key, prev_value = prev_self_attn_state
            saved_state = {"prev_key": prev_key, "prev_value": prev_value}
            self.self_attn._set_input_buffer(incremental_state, saved_state)
        x_self_attention, attn = self.self_attn(
            query=x_self_attention,
            key=x_self_attention,
            value=x_self_attention,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            need_weights=False,
            attn_mask=self_attn_mask,
        )
        x_self_attention = F.dropout(x_self_attention,
                                     p=self.dropout,
                                     training=self.training)
        x_self_attention = residual + x_self_attention
        x_self_attention = self.maybe_layer_norm(self.self_attn_layer_norm,
                                                 x_self_attention,
                                                 after=True)

        if self.encoder_attn is not None:
            residual = x
            x_encoder_attention = self.maybe_layer_norm(
                self.encoder_attn_layer_norm, x, before=True)
            if prev_encoder_attn_state is not None:
                if incremental_state is None:
                    incremental_state = {}
                prev_key, prev_value = prev_encoder_attn_state
                saved_state = {"prev_key": prev_key, "prev_value": prev_value}
                self.encoder_attn._set_input_buffer(incremental_state,
                                                    saved_state)
            x_encoder_attention, attn = self.encoder_attn(
                query=x_encoder_attention,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                need_weights=(not self.training and self.need_attn),
            )
            x_encoder_attention = F.dropout(x_encoder_attention,
                                            p=self.dropout,
                                            training=self.training)
            x_encoder_attention = residual + x_encoder_attention
            x_encoder_attention = self.maybe_layer_norm(
                self.encoder_attn_layer_norm, x_encoder_attention, after=True)

        if self.decoder_attn is not None:
            residual = x
            x_decoder_attention = self.maybe_layer_norm(
                self.decoder_attn_layer_norm, x, before=True)
            if prev_decoder_attn_state is not None:
                if incremental_state is None:
                    incremental_state = {}
                prev_key, prev_value = prev_decoder_attn_state
                saved_state = {"prev_key": prev_key, "prev_value": prev_value}
                self.decoder_attn._set_input_buffer(incremental_state,
                                                    saved_state)
            x_decoder_attention, attn = self.decoder_attn(
                query=x_decoder_attention,
                key=decoder_out,
                value=decoder_out,
                incremental_state=incremental_state,
                static_kv=True,
                need_weights=(not self.training and self.need_attn),
            )
            x_decoder_attention = F.dropout(x_decoder_attention,
                                            p=self.dropout,
                                            training=self.training)
            x_decoder_attention = residual + x_decoder_attention
            x_decoder_attention = self.maybe_layer_norm(
                self.decoder_attn_layer_norm, x_decoder_attention, after=True)
        x = x_self_attention
        if self.encoder_attn is not None:
            x = x + x_encoder_attention
        if self.decoder_attn is not None:
            x = x + x_decoder_attention

        residual = x
        x = self.maybe_layer_norm(self.final_layer_norm, x, before=True)
        x = self.activation_fn(self.fc1(x))
        x = F.dropout(x, p=self.activation_dropout, training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)
        if self.onnx_trace and incremental_state is not None:
            saved_state = self.self_attn._get_input_buffer(incremental_state)
            self_attn_state = saved_state["prev_key"], saved_state[
                "prev_value"]
            return x, attn, self_attn_state
        return x, attn

    def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
        assert before ^ after
        if after ^ self.normalize_before:
            return layer_norm(x)
        else:
            return x

    def make_generation_fast_(self, need_attn=False, **kwargs):
        self.need_attn = need_attn
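The maybe_layer_norm helper shared by the layers above encodes the pre-norm vs. post-norm choice in a single XOR: with normalize_before=False (the paper default) the norm fires on the after=True call, with normalize_before=True on the before=True call. A small standalone sketch of just that control flow, using a plain torch.nn.LayerNorm rather than the export-aware wrapper:

import torch
import torch.nn as nn

def maybe_layer_norm(layer_norm, x, normalize_before, before=False, after=False):
    # Exactly one of before/after is set by the caller; the norm is applied
    # only when the call site matches the configured placement.
    assert before ^ after
    return layer_norm(x) if (after ^ normalize_before) else x

ln = nn.LayerNorm(4)
x = torch.randn(2, 3, 4)

# Post-norm: untouched before the sublayer, normalized after it.
assert torch.equal(maybe_layer_norm(ln, x, False, before=True), x)
# Pre-norm: normalized before the sublayer, untouched after it.
assert torch.equal(maybe_layer_norm(ln, x, True, after=True), x)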
Example #23
class AANDecoderLayer(nn.Module):
    """
    Based on https://arxiv.org/abs/1805.00631
    """
    def __init__(self, args):
        super().__init__()
        self.embed_dim = args.decoder_embed_dim
        self.cross_self_attention = getattr(args, "cross_self_attention",
                                            False)

        self.avg_attn = AverageAttention(self.embed_dim,
                                         dropout=args.attention_dropout)

        # unlike the original paper, we use a single gate
        self.aan_gating_fc = fairseq_transformer.Linear(
            self.embed_dim * 2, self.embed_dim)

        self.dropout = args.dropout
        self.activation_fn = utils.get_activation_fn(
            activation=getattr(args, "activation_fn", "relu"))
        self.activation_dropout = getattr(args, "activation_dropout", 0)
        if self.activation_dropout == 0:
            # for backwards compatibility with models that use args.relu_dropout
            self.activation_dropout = getattr(args, "relu_dropout", 0)
        self.normalize_before = args.decoder_normalize_before

        # use layerNorm rather than FusedLayerNorm for exporting.
        # char_inputs can be used to determine this.
        # TODO  remove this once we update apex with the fix
        export = getattr(args, "char_inputs", False)
        self.avg_attn_layer_norm = LayerNorm(self.embed_dim, export=export)

        self.encoder_attn = MultiheadAttention(
            self.embed_dim,
            args.decoder_attention_heads,
            kdim=getattr(args, "encoder_embed_dim", None),
            vdim=getattr(args, "encoder_embed_dim", None),
            dropout=args.attention_dropout,
            encoder_decoder_attention=True,
        )
        self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=export)

        self.fc1 = fairseq_transformer.Linear(self.embed_dim,
                                              args.decoder_ffn_embed_dim)
        self.fc2 = fairseq_transformer.Linear(args.decoder_ffn_embed_dim,
                                              self.embed_dim)

        self.final_layer_norm = LayerNorm(self.embed_dim, export=export)
        self.need_attn = True

        self.onnx_trace = False

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    def forward(
        self,
        x,
        encoder_out=None,
        encoder_padding_mask=None,
        incremental_state=None,
        prev_self_attn_state=None,
        prev_attn_state=None,
        self_attn_mask=None,
        self_attn_padding_mask=None,
        need_attn=False,
        need_head_weights=False,
    ):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_padding_mask (ByteTensor, optional): binary
                ByteTensor of shape `(batch, src_len)` where padding
                elements are indicated by ``1``.
            need_attn (bool, optional): return attention weights
            need_head_weights (bool, optional): return attention weights
                for each head (default: return average over heads).

        Returns:
            encoded output of shape `(seq_len, batch, embed_dim)`

        The following are used for export tracing:
            prev_self_attn_state: [prev_sum, prev_pos]
                assumes AverageAttention without mask trick
            prev_attn_state: [prev_key, prev_value]
        """
        if need_head_weights:
            need_attn = True

        residual = x
        x = self.maybe_layer_norm(self.avg_attn_layer_norm, x, before=True)

        if prev_self_attn_state is not None:
            if incremental_state is None:
                incremental_state = {}
            prev_sum, prev_pos = prev_self_attn_state
            # (batch, embed) -> (seq, batch, embed)
            prev_sum = prev_sum.unsqueeze(0)
            saved_state = {"prev_sum": prev_sum, "prev_pos": prev_pos}
            self.avg_attn._set_input_buffer(incremental_state, saved_state)

        x, _ = self.avg_attn(
            value=x,
            mask_future_timesteps=True,
            incremental_state=incremental_state,
            mask_trick=self.training,
        )

        # unlike the original paper, we use a single gate
        gate = torch.sigmoid(
            self.aan_gating_fc(torch.cat([residual, x], dim=-1)))
        x = gate * x + (1 - gate) * residual

        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.avg_attn_layer_norm, x, after=True)

        if self.encoder_attn is not None:
            residual = x
            x = self.maybe_layer_norm(self.encoder_attn_layer_norm,
                                      x,
                                      before=True)
            if prev_attn_state is not None:
                if incremental_state is None:
                    incremental_state = {}
                prev_key, prev_value = prev_attn_state[:2]
                saved_state = {"prev_key": prev_key, "prev_value": prev_value}
                if len(prev_attn_state) >= 3:
                    saved_state["prev_key_padding_mask"] = prev_attn_state[2]
                self.encoder_attn._set_input_buffer(incremental_state,
                                                    saved_state)

            x, attn = self.encoder_attn(
                query=x,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                need_weights=need_attn
                or (not self.training and self.need_attn),
                need_head_weights=need_head_weights,
            )
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = residual + x
            x = self.maybe_layer_norm(self.encoder_attn_layer_norm,
                                      x,
                                      after=True)

        residual = x
        x = self.maybe_layer_norm(self.final_layer_norm, x, before=True)
        x = self.activation_fn(self.fc1(x))
        x = F.dropout(x, p=self.activation_dropout, training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)

        if self.onnx_trace and incremental_state is not None:
            saved_state = self.avg_attn._get_input_buffer(incremental_state)
            # remove sequence axis for export
            prev_sum = saved_state["prev_sum"]
            # (seq, batch, embed) -> (batch, embed)
            prev_sum = prev_sum.squeeze(0)
            prev_pos = saved_state["prev_pos"]
            self_attn_state = prev_sum, prev_pos
            return x, attn, self_attn_state

        return x, attn, None

    def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
        assert before ^ after
        if after ^ self.normalize_before:
            return layer_norm(x)
        else:
            return x

    def make_generation_fast_(self, need_attn=False, **kwargs):
        self.need_attn = need_attn
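As the comment above notes, AANDecoderLayer collapses the two AAN gates (input and forget) into a single one: gate = sigmoid(W[residual; x]) and out = gate * x + (1 - gate) * residual. A minimal standalone module illustrating just that combination (names and sizes are illustrative):

import torch
import torch.nn as nn

class SingleGate(nn.Module):
    """Convex combination of a sublayer output and its residual input."""

    def __init__(self, embed_dim):
        super().__init__()
        self.gating_fc = nn.Linear(embed_dim * 2, embed_dim)

    def forward(self, residual, x):
        gate = torch.sigmoid(self.gating_fc(torch.cat([residual, x], dim=-1)))
        return gate * x + (1 - gate) * residual

gate = SingleGate(embed_dim=8)
residual, x = torch.randn(5, 2, 8), torch.randn(5, 2, 8)
out = gate(residual, x)   # same shape as the inputs
assert out.shape == x.shape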
Example #24
class TransformerDecoderLayer(nn.Module):
    """Decoder layer block.

    In the original paper each operation (multi-head attention, encoder
    attention or FFN) is postprocessed with: `dropout -> add residual ->
    layernorm`. In the tensor2tensor code they suggest that learning is more
    robust when preprocessing each layer with layernorm and postprocessing with:
    `dropout -> add residual`. We default to the approach in the paper, but the
    tensor2tensor approach can be enabled by setting
    *args.decoder_normalize_before* to ``True``.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).
    """
    def __init__(self,
                 args,
                 no_encoder_attn=False,
                 add_bias_kv=False,
                 add_zero_attn=False,
                 LayerNum=None):
        super().__init__()

        global tmp_file

        self.args = args
        if not hasattr(self.args, 'mixed_precision'):
            self.args.mixed_precision = False
        if not hasattr(self.args, 'plot_variance'):
            self.args.plot_variance = False
        if not hasattr(self.args, 'plot_gradient'):
            self.args.plot_gradient = False

        self.normalize_before = args.decoder_normalize_before
        self.embed_dim = args.decoder_embed_dim
        self.cross_self_attention = getattr(args, 'cross_self_attention',
                                            False)

        self.layer_num = LayerNum
        if 'adaptive' in args.init_type:
            assert not self.normalize_before

            self.self_attn = MultiheadAttention(
                embed_dim=self.embed_dim,
                num_heads=args.decoder_attention_heads,
                dropout=args.attention_dropout,
                add_bias_kv=add_bias_kv,
                add_zero_attn=add_zero_attn,
                self_attention=not self.cross_self_attention)

            assert not no_encoder_attn
            self.encoder_attn = MultiheadAttention(
                self.embed_dim,
                args.decoder_attention_heads,
                kdim=getattr(args, 'encoder_embed_dim', None),
                vdim=getattr(args, 'encoder_embed_dim', None),
                dropout=args.attention_dropout,
                encoder_decoder_attention=True)

            self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim)
            self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)

            if 'adaptive-profiling' == args.init_type:
                if not tmp_file:
                    tmp_file = open('profile.ratio.init', 'w')
                self.self_ratio_change = nn.Parameter(
                    torch.ones(self.embed_dim))
                self.encoder_ratio_change = nn.Parameter(
                    torch.ones(self.embed_dim))
                self.fc_ratio_change = nn.Parameter(torch.ones(self.embed_dim))
            else:
                if not tmp_file:
                    tmp_file = open('profile.ratio.init', 'r')

                layer_iter, next_value = [
                    float(tup) for tup in tmp_file.readline().split()
                ]
                print('layer_num: {}, layer_iter: {}'.format(
                    self.layer_num, layer_iter))
                assert layer_iter == 3 * self.layer_num + 1
                print('decoder self ratio: {}'.format(next_value))
                self.self_ratio_change = nn.Parameter(
                    torch.ones(self.embed_dim))
                self.self_ratio_change.data.fill_(next_value)

                layer_iter, next_value = [
                    float(tup) for tup in tmp_file.readline().split()
                ]
                print('layer_num: {}, layer_iter: {}'.format(
                    self.layer_num, layer_iter))
                assert layer_iter == 3 * self.layer_num + 2
                print('decoder en ratio: {}'.format(next_value))
                self.encoder_ratio_change = nn.Parameter(
                    torch.ones(self.embed_dim))
                self.encoder_ratio_change.data.fill_(next_value)

                layer_iter, next_value = [
                    float(tup) for tup in tmp_file.readline().split()
                ]
                print('layer_num: {}, layer_iter: {}'.format(
                    self.layer_num, layer_iter))
                assert layer_iter == 3 * self.layer_num + 3
                print('decoder ffn ratio: {}'.format(next_value))
                self.fc_ratio_change = nn.Parameter(torch.ones(self.embed_dim))
                self.fc_ratio_change.data.fill_(next_value)

            export = getattr(args, 'char_inputs', False)
            self.self_attn_layer_norm = LayerNorm(self.embed_dim,
                                                  export=export)
            self.encoder_attn_layer_norm = LayerNorm(self.embed_dim,
                                                     export=export)
            self.final_layer_norm = LayerNorm(self.embed_dim, export=export)
        else:
            self.self_attn = MultiheadAttention(
                embed_dim=self.embed_dim,
                num_heads=args.decoder_attention_heads,
                dropout=args.attention_dropout,
                add_bias_kv=add_bias_kv,
                add_zero_attn=add_zero_attn,
                self_attention=not self.cross_self_attention)

            assert not no_encoder_attn
            self.encoder_attn = MultiheadAttention(
                self.embed_dim,
                args.decoder_attention_heads,
                kdim=getattr(args, 'encoder_embed_dim', None),
                vdim=getattr(args, 'encoder_embed_dim', None),
                dropout=args.attention_dropout,
                encoder_decoder_attention=True)

            self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim)
            self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)
            if args.init_type == 'looklinear':
                half = args.decoder_ffn_embed_dim // 2
                self.fc1.weight.data[half:, :] = -self.fc1.weight.data[:half, :]
                self.fc2.weight.data[:, half:] = -self.fc2.weight.data[:, :half]

            export = getattr(args, 'char_inputs', False)

            if args.init_type != 'rezero':
                self.self_attn_layer_norm = LayerNorm(self.embed_dim,
                                                      export=export)
                if no_encoder_attn:
                    self.encoder_attn = None
                    self.encoder_attn_layer_norm = None
                else:
                    self.encoder_attn_layer_norm = LayerNorm(self.embed_dim,
                                                             export=export)
                self.final_layer_norm = LayerNorm(self.embed_dim,
                                                  export=export)
            else:
                self.self_attn_layer_norm = None
                self.encoder_attn_layer_norm = None
                self.final_layer_norm = None

            if 'rezero' in args.init_type:
                self.rezero_weight = nn.Parameter(torch.Tensor([0]))
            else:
                assert args.init_type == 'default'
                self.rezero_weight = None

        self.dropout = args.dropout
        self.activation_fn = utils.get_activation_fn(
            activation=getattr(args, 'activation_fn', 'relu'))
        self.activation_dropout = getattr(args, 'activation_dropout', 0)
        if self.activation_dropout == 0:
            self.activation_dropout = getattr(args, 'relu_dropout', 0)

        self.need_attn = True

        self.onnx_trace = False

        if args.fp16:
            self.in_type = torch.half
        else:
            self.in_type = torch.float

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    def forward(
        self,
        x,
        encoder_out=None,
        encoder_padding_mask=None,
        incremental_state=None,
        prev_self_attn_state=None,
        prev_attn_state=None,
        self_attn_mask=None,
        self_attn_padding_mask=None,
        need_attn=False,
        need_head_weights=False,
    ):
        not_initialized = ('adaptive-profiling' == self.args.init_type) and (
            1.0 == self.self_ratio_change.min()) and self.training

        if need_head_weights:
            need_attn = True

        residual = x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True)

        if self.args.mixed_precision:
            x = x.type(self.in_type)
        if prev_self_attn_state is not None:
            if incremental_state is None:
                incremental_state = {}
            prev_key, prev_value = prev_self_attn_state[:2]
            saved_state = {"prev_key": prev_key, "prev_value": prev_value}
            if len(prev_self_attn_state) >= 3:
                saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
            self.self_attn._set_input_buffer(incremental_state, saved_state)
        if self.cross_self_attention and not (
                incremental_state is not None and "prev_key"
                in self.self_attn._get_input_buffer(incremental_state)):
            if self_attn_mask is not None:
                self_attn_mask = torch.cat((x.new(
                    x.size(0), encoder_out.size(0)).zero_(), self_attn_mask),
                                           dim=1)
            if self_attn_padding_mask is not None:
                if encoder_padding_mask is None:
                    encoder_padding_mask = self_attn_padding_mask.new(
                        encoder_out.size(1), encoder_out.size(0)).zero_()
                self_attn_padding_mask = torch.cat(
                    (encoder_padding_mask, self_attn_padding_mask), dim=1)
            y = torch.cat((encoder_out, x), dim=0)
        else:
            y = x
        x, attn = self.self_attn(
            query=x,
            key=y,
            value=y,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            need_weights=False,
            attn_mask=self_attn_mask,
        )
        x = F.dropout(x, p=self.dropout, training=self.training)

        if self.args.mixed_precision:
            x = x.float()
        if 'adaptive' in self.args.init_type:
            if not_initialized:
                global decoder_ratio, tmp_file
                tmp_layer_ind = self.layer_num * 3 + 1
                tmp_ratio = decoder_ratio
                tmp_file.write('{} {}\n'.format(tmp_layer_ind, tmp_ratio))
                self.self_ratio_change.data.fill_(tmp_ratio)
                print('decoder self attn ratio: {}'.format(tmp_ratio))
                input_var = np.var((residual * self.self_ratio_change).clone(
                ).cpu().float().data.contiguous().view(-1).numpy())
                output_var = np.var(
                    x.clone().cpu().float().data.contiguous().view(-1).numpy())
                decoder_ratio = np.sqrt(input_var + output_var)
            x0 = x + residual * self.self_ratio_change
        elif self.rezero_weight is not None:
            x0 = residual + self.rezero_weight * x
        else:
            x0 = residual + x
        x0 = self.maybe_layer_norm(self.self_attn_layer_norm, x0, after=True)
        if self.args.plot_gradient:
            x0.register_hook(lambda grad: print('{} decoder s-att: {}'.format(
                self.layer_num,
                grad.norm().item())))
        x = x0
        if self.encoder_attn is not None:
            residual = x
            x = self.maybe_layer_norm(self.encoder_attn_layer_norm,
                                      x0,
                                      before=True)

            if self.args.mixed_precision:
                x = x.type(self.in_type)
            if prev_attn_state is not None:
                if incremental_state is None:
                    incremental_state = {}
                prev_key, prev_value = prev_attn_state[:2]
                saved_state = {"prev_key": prev_key, "prev_value": prev_value}
                if len(prev_attn_state) >= 3:
                    saved_state["prev_key_padding_mask"] = prev_attn_state[2]
                self.encoder_attn._set_input_buffer(incremental_state,
                                                    saved_state)
            x, attn = self.encoder_attn(
                query=x,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                need_weights=need_attn
                or (not self.training and self.need_attn),
                need_head_weights=need_head_weights,
            )
            x = F.dropout(x, p=self.dropout, training=self.training)

            if self.args.mixed_precision:
                x = x.float()
            if 'adaptive' in self.args.init_type:
                if not_initialized:
                    tmp_layer_ind = self.layer_num * 3 + 2
                    tmp_ratio = decoder_ratio
                    tmp_file.write('{} {}\n'.format(tmp_layer_ind, tmp_ratio))
                    self.encoder_ratio_change.data.fill_(tmp_ratio)
                    print('decoder encoder attn ratio: {}'.format(tmp_ratio))
                    input_var = np.var(
                        (residual * self.encoder_ratio_change).clone().cpu(
                        ).float().data.contiguous().view(-1).numpy())
                    output_var = np.var(
                        x.clone().cpu().float().data.contiguous().view(
                            -1).numpy())
                    decoder_ratio = np.sqrt(input_var + output_var)
                x1 = x + residual * self.encoder_ratio_change
            elif self.rezero_weight is not None:
                x1 = residual + self.rezero_weight * x
            else:
                x1 = residual + x
            x1 = self.maybe_layer_norm(self.encoder_attn_layer_norm,
                                       x1,
                                       after=True)
            if self.args.plot_gradient:
                x1.register_hook(lambda grad: print(
                    '{} decoder e-att: {}'.format(self.layer_num,
                                                  grad.norm().item())))
            x = x1
        residual = x
        x = self.maybe_layer_norm(self.final_layer_norm, x, before=True)

        if self.args.mixed_precision:
            x = x.type(self.in_type)
        bx = self.fc1(x)
        hx = self.activation_fn(bx)
        x = F.dropout(hx, p=self.activation_dropout, training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)

        if self.args.mixed_precision:
            x = x.float()
        if 'adaptive' in self.args.init_type:
            if not_initialized:
                tmp_layer_ind = self.layer_num * 3 + 3
                tmp_ratio = decoder_ratio
                tmp_file.write('{} {}\n'.format(tmp_layer_ind, tmp_ratio))
                self.fc_ratio_change.data.fill_(tmp_ratio)
                print('decoder ffn ratio: {}'.format(tmp_ratio))
                input_var = np.var((residual * self.fc_ratio_change).clone(
                ).cpu().float().data.contiguous().view(-1).numpy())
                output_var = np.var(
                    x.clone().cpu().float().data.contiguous().view(-1).numpy())
                decoder_ratio = np.sqrt(input_var + output_var)
            x2 = x + residual * self.fc_ratio_change
        elif self.rezero_weight is not None:
            x2 = residual + self.rezero_weight * x
        else:
            x2 = residual + x
        x2 = self.maybe_layer_norm(self.final_layer_norm, x2, after=True)
        if self.args.plot_gradient:
            x2.register_hook(lambda grad: print('{} decoder ffn: {}'.format(
                self.layer_num,
                grad.norm().item())))

        if self.onnx_trace and incremental_state is not None:
            saved_state = self.self_attn._get_input_buffer(incremental_state)
            if self_attn_padding_mask is not None:
                self_attn_state = saved_state["prev_key"], saved_state[
                    "prev_value"], saved_state["prev_key_padding_mask"]
            else:
                self_attn_state = saved_state["prev_key"], saved_state[
                    "prev_value"]
            return x2, attn, self_attn_state

        return x2, attn

    def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
        if self.args.init_type == 'rezero':
            return x

        assert before ^ after
        if after ^ self.normalize_before:
            return layer_norm(x)
        else:
            return x

    def make_generation_fast_(self, need_attn=False, **kwargs):
        self.need_attn = need_attn
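For the 'rezero' init_type above, maybe_layer_norm becomes a no-op and every sublayer output is scaled by a learned scalar initialised to zero before the residual is added. A minimal self-contained sketch of that residual rule (an illustrative wrapper, not the full decoder layer):

import torch
import torch.nn as nn

class ReZeroResidual(nn.Module):
    """out = x + rezero_weight * sublayer(x), with rezero_weight starting at 0."""

    def __init__(self, sublayer):
        super().__init__()
        self.sublayer = sublayer
        self.rezero_weight = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        return x + self.rezero_weight * self.sublayer(x)

block = ReZeroResidual(nn.Linear(8, 8))
x = torch.randn(4, 8)
# At initialisation the block is the identity, so gradients flow cleanly
# through the residual path without any layer normalisation.
assert torch.allclose(block(x), x)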
Example #25
class FixupTransformerDecoderLayer(nn.Module):
    """Decoder layer block.

    In the original paper each operation (multi-head attention, encoder
    attention or FFN) is postprocessed with: `dropout -> add residual ->
    layernorm`. In the tensor2tensor code they suggest that learning is more
    robust when preprocessing each layer with layernorm and postprocessing with:
    `dropout -> add residual`. We default to the approach in the paper, but the
    tensor2tensor approach can be enabled by setting
    *args.decoder_normalize_before* to ``True``.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).
    """
    def fixup_initialization(self, args):
        temp_state_dic = {}
        de_layers = args.decoder_layers

        if args.Tfixup:
            for name, param in self.named_parameters():
                if name in [
                        "fc1.weight",
                        "fc2.weight",
                        "self_attn.out_proj.weight",
                        "encoder_attn.out_proj.weight",
                ]:
                    temp_state_dic[name] = (9 * de_layers)**(-1. / 4.) * param
                elif name in [
                        "self_attn.v_proj.weight",
                        "encoder_attn.v_proj.weight",
                ]:
                    temp_state_dic[name] = ((9 * de_layers)**(-1. / 4.) *
                                            param * (2**0.5))

        for name in self.state_dict():
            if name not in temp_state_dic:
                temp_state_dic[name] = self.state_dict()[name]
        self.load_state_dict(temp_state_dic)

    def __init__(self, args, no_encoder_attn=False):
        super().__init__()
        self.embed_dim = args.decoder_embed_dim
        if args.max_relative_length == -1:
            self.self_attn = MultiheadAttention(
                self.embed_dim,
                args.decoder_attention_heads,
                dropout=args.attention_dropout,
            )
        else:
            self.self_attn = RelativeMultiheadAttention(
                self.embed_dim,
                args.decoder_attention_heads,
                args.max_relative_length,
                dropout=args.attention_dropout,
                k_only=args.k_only,
            )

        self.dropout = args.dropout
        self.relu_dropout = args.relu_dropout
        self.normalize_before = args.decoder_normalize_before

        if no_encoder_attn:
            self.encoder_attn = None
            self.encoder_attn_layer_norm = None
        else:
            self.encoder_attn = MultiheadAttention(
                self.embed_dim,
                args.decoder_attention_heads,
                dropout=args.attention_dropout,
            )
            self.encoder_attn_layer_norm = LayerNorm(self.embed_dim)

        self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim)
        self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)

        self.noLN = args.dont_use_layernorm
        if not self.noLN:
            self.self_attn_layer_norm = LayerNorm(self.embed_dim)
            self.final_layer_norm = LayerNorm(self.embed_dim)
        self.need_attn = True

        self.onnx_trace = False

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    def forward(self,
                x,
                encoder_out,
                encoder_padding_mask,
                incremental_state,
                prev_self_attn_state=None,
                prev_attn_state=None,
                self_attn_mask=None,
                self_attn_padding_mask=None):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_padding_mask (ByteTensor): binary ByteTensor of shape
                `(batch, src_len)` where padding elements are indicated by ``1``.

        Returns:
            encoded output of shape `(seq_len, batch, embed_dim)`
        """
        residual = x
        if self.normalize_before and not self.noLN:
            x = self.self_attn_layer_norm(x)
        if prev_self_attn_state is not None:
            if incremental_state is None:
                incremental_state = {}
            prev_key, prev_value = prev_self_attn_state
            saved_state = {"prev_key": prev_key, "prev_value": prev_value}
            self.self_attn._set_input_buffer(incremental_state, saved_state)
        x, _ = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            need_weights=False,
            attn_mask=self_attn_mask,
        )
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        if not self.normalize_before and not self.noLN:
            x = self.self_attn_layer_norm(x)

        attn = None
        if self.encoder_attn is not None:
            residual = x
            if self.normalize_before and not self.noLN:
                x = self.encoder_attn_layer_norm(x)
            if prev_attn_state is not None:
                if incremental_state is None:
                    incremental_state = {}
                prev_key, prev_value = prev_attn_state
                saved_state = {"prev_key": prev_key, "prev_value": prev_value}
                self.encoder_attn._set_input_buffer(incremental_state,
                                                    saved_state)
            x, attn = self.encoder_attn(
                query=x,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                need_weights=(not self.training and self.need_attn),
            )
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = residual + x
            if not self.normalize_before and not self.noLN:
                x = self.encoder_attn_layer_norm(x)

        residual = x
        if self.normalize_before and not self.noLN:
            x = self.final_layer_norm(x)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, p=self.relu_dropout, training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        if not self.normalize_before and not self.noLN:
            x = self.final_layer_norm(x)
        if self.onnx_trace and incremental_state is not None:
            saved_state = self.self_attn._get_input_buffer(incremental_state)
            self_attn_state = saved_state["prev_key"], saved_state[
                "prev_value"]
            return x, attn, self_attn_state
        return x, attn

    def make_generation_fast_(self, need_attn=False, **kwargs):
        self.need_attn = need_attn
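fixup_initialization above rescales the FFN and attention output/value projection weights by (9 * decoder_layers)^(-1/4), with an extra sqrt(2) on the value projections, in the spirit of T-Fixup. A small standalone sketch of that rescaling applied to a toy FFN (the name filter and module here are illustrative, not the fairseq layer):

import torch
import torch.nn as nn

def tfixup_scale_(module, num_decoder_layers):
    """Scale weight matrices in place by (9 * N) ** (-1/4)."""
    scale = (9 * num_decoder_layers) ** (-1.0 / 4.0)
    with torch.no_grad():
        for name, param in module.named_parameters():
            if name.endswith("weight"):
                param.mul_(scale)

ffn = nn.Sequential(nn.Linear(8, 32), nn.ReLU(), nn.Linear(32, 8))
before = ffn[0].weight.clone()
tfixup_scale_(ffn, num_decoder_layers=6)
assert torch.allclose(ffn[0].weight, before * (9 * 6) ** (-0.25))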