Example #1
class HybridRNNDecoder(FairseqIncrementalDecoder):
    """
    Decoder with general structure of Chen et al., The Best of Both Worlds:
    Combining Recent Advances in Neural Machine Translation, 2018.
    https://arxiv.org/abs/1804.09849
    """
    def _init_dims(self, args, src_dict, dst_dict, embed_tokens):
        self.dropout = args.dropout

        embed_dim = embed_tokens.embedding_dim
        self.embed_tokens = embed_tokens

        self.lstm_units = args.decoder_lstm_units
        self.num_layers = args.decoder_layers
        self.initial_input_dim = embed_dim

        self.encoder_output_dim = args.encoder_embed_dim
        if args.decoder_reduced_attention_dim is None:
            self.attention_dim = self.encoder_output_dim
        else:
            self.attention_dim = args.decoder_reduced_attention_dim
        self.input_dim = self.lstm_units + self.attention_dim

        self.num_attention_heads = args.decoder_attention_heads
        self.bottleneck_dim = args.decoder_out_embed_dim

    def _init_components(self, args, src_dict, dst_dict, embed_tokens):
        self.initial_rnn_layer = nn.LSTM(input_size=self.initial_input_dim,
                                         hidden_size=self.lstm_units)

        self.proj_encoder_layer = None
        if self.attention_dim != self.encoder_output_dim:
            self.proj_encoder_layer = fairseq_transformer.Linear(
                self.encoder_output_dim, self.attention_dim)

        self.proj_layer = None
        if self.lstm_units != self.attention_dim:
            self.proj_layer = fairseq_transformer.Linear(
                self.lstm_units, self.attention_dim)

        self.attention = MultiheadAttention(
            self.attention_dim,
            self.num_attention_heads,
            dropout=args.attention_dropout,
            encoder_decoder_attention=True,
        )

        self.extra_rnn_layers = nn.ModuleList([])
        for _ in range(self.num_layers - 1):
            self.extra_rnn_layers.append(
                nn.LSTM(input_size=self.input_dim,
                        hidden_size=self.lstm_units))

        self.bottleneck_layer = None
        if self.bottleneck_dim is not None:
            self.out_embed_dim = self.bottleneck_dim
            self.bottleneck_layer = fairseq_transformer.Linear(
                self.input_dim, self.out_embed_dim)
        else:
            self.out_embed_dim = self.input_dim

        self.embed_out = nn.Parameter(
            torch.Tensor(len(dst_dict), self.out_embed_dim))
        nn.init.normal_(self.embed_out, mean=0, std=self.out_embed_dim**-0.5)

        self.vocab_reduction_module = None
        if args.vocab_reduction_params:
            self.vocab_reduction_module = vocab_reduction.VocabReduction(
                src_dict,
                dst_dict,
                args.vocab_reduction_params,
                fp16=args.fp16)

        self.onnx_trace = False

    def __init__(self, args, src_dict, dst_dict, embed_tokens):
        super().__init__(dst_dict)
        self._init_dims(args, src_dict, dst_dict, embed_tokens)
        self._init_components(args, src_dict, dst_dict, embed_tokens)

    # Enable dependency injection by subclasses
    def _unpack_encoder_out(self, encoder_out):
        """ Allow taking encoder_out from different architecture which
        may have different formats.
        """
        return encoder_out

    def _init_hidden(self, encoder_out, batch_size):
        """ Initialize with latent code if available otherwise zeros."""
        return torch.zeros([1, batch_size, self.lstm_units])

    def _concat_latent_code(self, x, encoder_out):
        """ Concat latent code, if available in encoder_out, which is the
        case in subclass.
        """
        return x

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    def _embed_prev_outputs(self, prev_output_tokens, incremental_state=None):
        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
        x = self.embed_tokens(prev_output_tokens)
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        return x, prev_output_tokens

    def forward(
        self,
        prev_output_tokens,
        encoder_out,
        incremental_state=None,
        possible_translation_tokens=None,
        timestep=None,
    ):
        x, prev_output_tokens = self._embed_prev_outputs(
            prev_output_tokens=prev_output_tokens,
            incremental_state=incremental_state)
        return self._forward_given_embeddings(
            embed_out=x,
            prev_output_tokens=prev_output_tokens,
            encoder_out=encoder_out,
            incremental_state=incremental_state,
            possible_translation_tokens=possible_translation_tokens,
            timestep=timestep,
        )

    def _forward_given_embeddings(
        self,
        embed_out,
        prev_output_tokens,
        encoder_out,
        incremental_state=None,
        possible_translation_tokens=None,
        timestep=None,
    ):
        x = embed_out
        (encoder_x, src_tokens,
         encoder_padding_mask) = self._unpack_encoder_out(encoder_out)
        bsz, seqlen = prev_output_tokens.size()

        state_outputs = []
        if incremental_state is not None:
            prev_states = utils.get_incremental_state(self, incremental_state,
                                                      "cached_state")
            if prev_states is None:
                prev_states = self._init_prev_states(encoder_out)

            # final 2 states of list are projected key and value
            saved_state = {
                "prev_key": prev_states[-2],
                "prev_value": prev_states[-1]
            }
            self.attention._set_input_buffer(incremental_state, saved_state)

        if incremental_state is not None:
            # first num_layers pairs of states are (prev_hidden, prev_cell)
            # for each layer
            h_prev = prev_states[0]
            c_prev = prev_states[1]
        else:
            h_prev = self._init_hidden(encoder_out, bsz).type_as(x)
            c_prev = torch.zeros([1, bsz, self.lstm_units]).type_as(x)

        x = self._concat_latent_code(x, encoder_out)
        x, (h_next, c_next) = self.initial_rnn_layer(x, (h_prev, c_prev))
        if incremental_state is not None:
            state_outputs.extend([h_next, c_next])

        x = F.dropout(x, p=self.dropout, training=self.training)

        if self.proj_encoder_layer is not None:
            encoder_x = self.proj_encoder_layer(encoder_x)

        attention_in = x
        if self.proj_layer is not None:
            attention_in = self.proj_layer(x)

        attention_out, attention_weights = self.attention(
            query=attention_in,
            key=encoder_x,
            value=encoder_x,
            key_padding_mask=encoder_padding_mask,
            incremental_state=incremental_state,
            static_kv=True,
            need_weights=(not self.training),
        )

        for i, layer in enumerate(self.extra_rnn_layers):
            residual = x
            rnn_input = torch.cat([x, attention_out], dim=2)
            rnn_input = self._concat_latent_code(rnn_input, encoder_out)

            if incremental_state is not None:
                # first num_layers pairs of states are (prev_hidden, prev_cell)
                # for each layer
                h_prev = prev_states[2 * i + 2]
                c_prev = prev_states[2 * i + 3]
            else:
                h_prev = self._init_hidden(encoder_out, bsz).type_as(x)
                c_prev = torch.zeros([1, bsz, self.lstm_units]).type_as(x)

            x, (h_next, c_next) = layer(rnn_input, (h_prev, c_prev))
            if incremental_state is not None:
                state_outputs.extend([h_next, c_next])
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = x + residual

        x = torch.cat([x, attention_out], dim=2)
        x = self._concat_latent_code(x, encoder_out)
        if self.bottleneck_layer is not None:
            x = self.bottleneck_layer(x)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        if (self.vocab_reduction_module is not None
                and possible_translation_tokens is None):
            decoder_input_tokens = prev_output_tokens.contiguous()
            possible_translation_tokens = self.vocab_reduction_module(
                src_tokens, decoder_input_tokens=decoder_input_tokens)

        output_weights = self.embed_out
        if possible_translation_tokens is not None:
            output_weights = output_weights.index_select(
                dim=0, index=possible_translation_tokens)

        logits = F.linear(x, output_weights)

        if incremental_state is not None:
            # encoder projections can be reused at each incremental step
            state_outputs.extend([prev_states[-2], prev_states[-1]])
            utils.set_incremental_state(self, incremental_state,
                                        "cached_state", state_outputs)

        return logits, attention_weights, possible_translation_tokens

    def max_positions(self):
        """Maximum output length supported by the decoder."""
        return int(1e5)  # an arbitrary large number

    def _init_prev_states(self, encoder_out):
        """
        Initial (hidden, cell) values for LSTM layers are zero.

        For encoder-decoder attention, key and value are computed once from
        the encoder outputs and stay the same throughout decoding.
        """
        (encoder_x, src_tokens,
         encoder_padding_mask) = self._unpack_encoder_out(encoder_out)
        batch_size = torch.onnx.operators.shape_as_tensor(encoder_x)[1]

        if self.proj_encoder_layer is not None:
            encoder_x = self.proj_encoder_layer(encoder_x)

        states = []
        for _ in range(self.num_layers):
            hidden = self._init_hidden(encoder_out,
                                       batch_size).type_as(encoder_x)
            cell = torch.zeros([1, batch_size,
                                self.lstm_units]).type_as(encoder_x)
            states.extend([hidden, cell])

        # (key, value) for encoder-decoder attention are computed from the
        # encoder output and remain the same throughout decoding
        key = self.attention.k_proj(encoder_x)
        value = self.attention.v_proj(encoder_x)

        # (key, value) kept in shape (bsz, num_heads, seq_len, head_dim)
        # to avoid repeated transpose operations
        seq_len, batch_size_int, _ = encoder_x.shape
        num_heads = self.attention.num_heads
        head_dim = self.attention.head_dim
        key = (key.view(seq_len, batch_size_int * num_heads,
                        head_dim).transpose(0,
                                            1).view(batch_size_int, num_heads,
                                                    seq_len, head_dim))
        value = (value.view(seq_len, batch_size_int * num_heads,
                            head_dim).transpose(0, 1).view(
                                batch_size_int, num_heads, seq_len, head_dim))
        states.extend([key, value])

        return states

    def reorder_incremental_state(self, incremental_state, new_order):
        # parent reorders attention model
        super().reorder_incremental_state(incremental_state, new_order)

        cached_state = utils.get_incremental_state(self, incremental_state,
                                                   "cached_state")
        if cached_state is None:
            return

        # Last 2 elements of prev_states are encoder projections
        # used for ONNX export
        for i, state in enumerate(cached_state[:-2]):
            cached_state[i] = state.index_select(1, new_order)

        utils.set_incremental_state(self, incremental_state, "cached_state",
                                    cached_state)
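
A minimal, self-contained sketch (toy sizes chosen only for illustration, not taken from the original code) of the key/value reshape performed in _init_prev_states above: the encoder-side projections are computed once per source sentence in time-major layout (seq_len, bsz, embed_dim) and cached as (bsz, num_heads, seq_len, head_dim), so no per-step transpose is needed during incremental decoding.

import torch

# toy dimensions (assumptions for illustration only)
seq_len, bsz, num_heads, head_dim = 7, 2, 4, 16
embed_dim = num_heads * head_dim

# stand-in for self.attention.k_proj(encoder_x), shape T x B x C
key = torch.randn(seq_len, bsz, embed_dim)

# same view/transpose/view chain as in _init_prev_states
cached_key = (
    key.view(seq_len, bsz * num_heads, head_dim)
    .transpose(0, 1)
    .view(bsz, num_heads, seq_len, head_dim)
)

# sanity check against an explicit head split + permute
reference = key.view(seq_len, bsz, num_heads, head_dim).permute(1, 2, 0, 3)
assert torch.equal(cached_key, reference)

Because the cached key/value are batch-first, Example #2 below reorders them in reorder_incremental_state with index_select along dim 0 rather than dim 1.

A second small sketch (again with toy sizes; the vocab_reduction_module itself is not reproduced, only its candidate-id output is mocked) of the reduced-vocabulary output projection at the end of _forward_given_embeddings: when possible_translation_tokens is given, only the corresponding rows of embed_out are selected before the logits are computed, so the output distribution covers just the candidate set.

import torch
import torch.nn.functional as F

vocab_size, out_embed_dim, bsz, tgt_len = 100, 32, 2, 5

embed_out = torch.randn(vocab_size, out_embed_dim)  # stand-in for self.embed_out
x = torch.randn(bsz, tgt_len, out_embed_dim)        # decoder output, B x T x C

# stand-in for the candidate token ids produced by vocab_reduction_module
possible_translation_tokens = torch.tensor([0, 3, 17, 42])

output_weights = embed_out.index_select(dim=0, index=possible_translation_tokens)
logits = F.linear(x, output_weights)                # B x T x num_candidates
assert logits.shape == (bsz, tgt_len, possible_translation_tokens.numel())
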
Example #2
class HybridRNNDecoder(FairseqIncrementalDecoder):
    """
    Transformer decoder consisting of *args.decoder_layers* layers. Each layer
    is a :class:`TransformerDecoderLayer`.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): decoding dictionary
        embed_tokens (torch.nn.Embedding): output embedding
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).
    """

    def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
        super().__init__(dictionary)
        self.register_buffer('version', torch.Tensor([3]))

        self.dropout = args.dropout
        self.share_input_output_embed = args.share_decoder_input_output_embed

        embed_dim = embed_tokens.embedding_dim
        self.embed_tokens = embed_tokens

        self.lstm_units = args.decoder_lstm_units
        self.num_layers = args.decoder_layers
        self.initial_input_dim = embed_dim

        self.encoder_output_dim = args.encoder_embed_dim
        if args.decoder_reduced_attention_dim is None:
            self.attention_dim = self.encoder_output_dim
        else:
            self.attention_dim = args.decoder_reduced_attention_dim
        self.input_dim = self.lstm_units + self.attention_dim

        self.num_attention_heads = args.decoder_attention_heads
        self.bottleneck_dim = args.decoder_out_embed_dim


        self.initial_rnn_layer = nn.LSTM(
            input_size=self.initial_input_dim, hidden_size=self.lstm_units
        )
        self.initial_layernorm = LayerNorm(self.lstm_units)

        self.proj_encoder_layer = None
        if self.attention_dim != self.encoder_output_dim:
            self.proj_encoder_layer = Linear(
                self.encoder_output_dim, self.attention_dim
            )

        self.proj_layer = None
        if self.lstm_units != self.attention_dim:
            self.proj_layer = Linear(
                self.lstm_units, self.attention_dim
            )

        self.attention = MultiheadAttention(
            self.attention_dim,
            self.num_attention_heads,
            dropout=args.attention_dropout,
            encoder_decoder_attention=True,
        )

        self.extra_rnn_layers = nn.ModuleList([])
        self.extra_layernorms = nn.ModuleList([])
        for _ in range(self.num_layers - 1):
            self.extra_rnn_layers.append(
                nn.LSTM(input_size=self.input_dim, hidden_size=self.lstm_units)
            )
            self.extra_layernorms.append(
                LayerNorm(self.lstm_units)
            )

        self.bottleneck_layer = None
        if self.bottleneck_dim is not None:
            self.out_embed_dim = self.bottleneck_dim
            self.bottleneck_layer = Linear(
                self.input_dim, self.out_embed_dim
            )
        else:
            self.out_embed_dim = self.input_dim

        if not self.share_input_output_embed:
            self.embed_out = nn.Parameter(torch.Tensor(len(dictionary), self.out_embed_dim))
            nn.init.normal_(self.embed_out, mean=0, std=self.out_embed_dim ** -0.5)
        else:
            assert self.bottleneck_dim == args.decoder_embed_dim, (self.bottleneck_dim, args.decoder_embed_dim)

    def _unpack_encoder_out(self, encoder_out):
        """ Allow taking encoder_out from different architecture which
        may have different formats.
        """
        # return encoder_out['encoder_out'], encoder_out['encoder_padding_mask']
        return encoder_out.encoder_out, encoder_out.encoder_padding_mask

    def _init_hidden(self, encoder_out, batch_size):
        """ Initialize with latent code if available otherwise zeros."""
        return torch.zeros([1, batch_size, self.lstm_units])

    def _concat_latent_code(self, x, encoder_out):
        """ Concat latent code, if available in encoder_out, which is the
        case in subclass.
        """
        return x

    def _embed_prev_outputs(self, prev_output_tokens, incremental_state=None):
        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
        x = self.embed_tokens(prev_output_tokens)
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        return x, prev_output_tokens

    def forward(
        self,
        prev_output_tokens,
        encoder_out,
        incremental_state=None,
        possible_translation_tokens=None,
        timestep=None,
    ):
        x, prev_output_tokens = self._embed_prev_outputs(
            prev_output_tokens=prev_output_tokens, incremental_state=incremental_state
        )
        return self._forward_given_embeddings(
            embed_out=x,
            prev_output_tokens=prev_output_tokens,
            encoder_out=encoder_out,
            incremental_state=incremental_state,
            possible_translation_tokens=possible_translation_tokens,
            timestep=timestep,
        )

    def _forward_given_embeddings(
        self,
        embed_out,
        prev_output_tokens,
        encoder_out,
        incremental_state=None,
        possible_translation_tokens=None,
        timestep=None,
    ):
        x = embed_out
        (encoder_x, encoder_padding_mask) = self._unpack_encoder_out(encoder_out)
        bsz, seqlen = prev_output_tokens.size()

        state_outputs = []
        if incremental_state is not None:
            prev_states = utils.get_incremental_state(
                self, incremental_state, "cached_state"
            )
            if prev_states is None:
                prev_states = self._init_prev_states(encoder_out)

            # final 2 states of list are projected key and value
            saved_state = {"prev_key": prev_states[-2], "prev_value": prev_states[-1]}
            self.attention._set_input_buffer(incremental_state, saved_state)

        if incremental_state is not None:
            # first num_layers pairs of states are (prev_hidden, prev_cell)
            # for each layer
            h_prev = prev_states[0]
            c_prev = prev_states[1]
        else:
            h_prev = self._init_hidden(encoder_out, bsz).type_as(x)
            c_prev = torch.zeros([1, bsz, self.lstm_units]).type_as(x)

        x = self._concat_latent_code(x, encoder_out)
        x, (h_next, c_next) = self.initial_rnn_layer(x, (h_prev, c_prev))
        x = self.initial_layernorm(x)
        if incremental_state is not None:
            state_outputs.extend([h_next, c_next])

        x = F.dropout(x, p=self.dropout, training=self.training)

        if self.proj_encoder_layer is not None:
            encoder_x = self.proj_encoder_layer(encoder_x)

        attention_in = x
        if self.proj_layer is not None:
            attention_in = self.proj_layer(x)

        attention_out, attention_weights = self.attention(
            query=attention_in,
            key=encoder_x,
            value=encoder_x,
            key_padding_mask=encoder_padding_mask,
            incremental_state=incremental_state,
            static_kv=True,
            need_weights=(not self.training),
        )

        for i, layer in enumerate(self.extra_rnn_layers):
            residual = x
            rnn_input = torch.cat([x, attention_out], dim=2)
            rnn_input = self._concat_latent_code(rnn_input, encoder_out)

            if incremental_state is not None:
                # first num_layers pairs of states are (prev_hidden, prev_cell)
                # for each layer
                h_prev = prev_states[2 * i + 2]
                c_prev = prev_states[2 * i + 3]
            else:
                h_prev = self._init_hidden(encoder_out, bsz).type_as(x)
                c_prev = torch.zeros([1, bsz, self.lstm_units]).type_as(x)

            x, (h_next, c_next) = layer(rnn_input, (h_prev, c_prev))
            if incremental_state is not None:
                state_outputs.extend([h_next, c_next])
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = x + residual
            x = self.extra_layernorms[i](x)

        x = torch.cat([x, attention_out], dim=2)
        x = self._concat_latent_code(x, encoder_out)
        if self.bottleneck_layer is not None:
            x = self.bottleneck_layer(x)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        if self.share_input_output_embed:
            logits = F.linear(x, self.embed_tokens.weight)
        else:
            logits = F.linear(x, self.embed_out)

        if incremental_state is not None:
            # encoder projections can be reused at each incremental step
            state_outputs.extend([prev_states[-2], prev_states[-1]])
            utils.set_incremental_state(
                self, incremental_state, "cached_state", state_outputs
            )

        return logits, attention_weights

    def max_positions(self):
        """Maximum output length supported by the decoder."""
        return 1024  # fixed cap on the output length

    def _init_prev_states(self, encoder_out):
        """
        Initial (hidden, cell) values for LSTM layers are zero.

        For encoder-decoder attention, key and value are computed once from
        the encoder outputs and stay the same throughout decoding.
        """
        (encoder_x, encoder_padding_mask) = self._unpack_encoder_out(encoder_out)
        batch_size = torch.onnx.operators.shape_as_tensor(encoder_x)[1]

        if self.proj_encoder_layer is not None:
            encoder_x = self.proj_encoder_layer(encoder_x)

        states = []
        for _ in range(self.num_layers):
            hidden = self._init_hidden(encoder_out, batch_size).type_as(encoder_x)
            cell = torch.zeros([1, batch_size, self.lstm_units]).type_as(encoder_x)
            states.extend([hidden, cell])

        # (key, value) for encoder-decoder attention are computed from the
        # encoder output and remain the same throughout decoding
        key = self.attention.k_proj(encoder_x)
        value = self.attention.v_proj(encoder_x)

        # (key, value) kept in shape (bsz, num_heads, seq_len, head_dim)
        # to avoid repeated transpose operations
        seq_len, batch_size_int, _ = encoder_x.shape
        num_heads = self.attention.num_heads
        head_dim = self.attention.head_dim
        key = (
            key.view(seq_len, batch_size_int * num_heads, head_dim)
            .transpose(0, 1)
            .view(batch_size_int, num_heads, seq_len, head_dim)
        )
        value = (
            value.view(seq_len, batch_size_int * num_heads, head_dim)
            .transpose(0, 1)
            .view(batch_size_int, num_heads, seq_len, head_dim)
        )
        states.extend([key, value])

        return states

    def reorder_incremental_state(self, incremental_state, new_order):
        # parent reorders attention model
        super().reorder_incremental_state(incremental_state, new_order)

        cached_state = utils.get_incremental_state(
            self, incremental_state, "cached_state"
        )
        if cached_state is None:
            return

        # Last 2 elements of prev_states are encoder projections
        # used for ONNX export
        for i, state in enumerate(cached_state[:-2]):
            cached_state[i] = state.index_select(1, new_order)
        for i in [-2, -1]:
            cached_state[i] = cached_state[i].index_select(0, new_order)

        utils.set_incremental_state(
            self, incremental_state, "cached_state", cached_state
        )
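
A minimal, self-contained sketch (toy sizes; nn.LayerNorm stands in for the LayerNorm imported in the original) of the per-layer pattern inside the extra_rnn_layers loop of Example #2: the attention output is concatenated to the previous layer's output, fed through an LSTM, then dropout, a residual connection, and, in this variant, a LayerNorm.

import torch
import torch.nn as nn
import torch.nn.functional as F

# toy dimensions (assumptions for illustration only)
tgt_len, bsz, lstm_units, attention_dim = 5, 2, 16, 16
input_dim = lstm_units + attention_dim

layer = nn.LSTM(input_size=input_dim, hidden_size=lstm_units)
layernorm = nn.LayerNorm(lstm_units)

x = torch.randn(tgt_len, bsz, lstm_units)         # previous layer output, T x B x C
attention_out = torch.randn(tgt_len, bsz, attention_dim)
h_prev = torch.zeros(1, bsz, lstm_units)
c_prev = torch.zeros(1, bsz, lstm_units)

residual = x
rnn_input = torch.cat([x, attention_out], dim=2)  # T x B x (lstm_units + attention_dim)
x, (h_next, c_next) = layer(rnn_input, (h_prev, c_prev))
x = F.dropout(x, p=0.1, training=False)
x = x + residual                                  # residual connection around the LSTM
x = layernorm(x)                                  # the LayerNorm added in Example #2
assert x.shape == (tgt_len, bsz, lstm_units)

The extra LayerNorm after the initial and extra RNN layers (initial_layernorm, extra_layernorms) is the main structural difference of Example #2 besides the tied output embedding; the attention, residual, and incremental-state handling mirror Example #1.
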