Example no. 1
    def forward(self, input_states, ctx_states, mask_input, mask_ctx,
                **kwargs):
        """Return output of shape `(seq_len, batch_size, depth)`."""
        if self.pos_init:
            # add position information to the inputs before self-attention
            pos_encoding = positional_encodings_like(input_states)
            output = input_states + pos_encoding
        else:
            pos_encoding = None
            output = input_states
        self_padding_mask = (1 - mask_input).byte()
        output = self.self_encoder(x=output, padding_mask=self_padding_mask)
        if self.pos_slf_attn is not None:
            # positional self-attention: queries/keys are the position
            # encodings, values are the encoded states
            if pos_encoding is None:
                pos_encoding = positional_encodings_like(output)
            output = self.pos_slf_attn(x=pos_encoding,
                                       value=output,
                                       padding_mask=self_padding_mask)
        # attend over the (static) context states
        inter_padding_mask = (1 - mask_ctx).byte()
        output = self.cond_attn(output,
                                ctx_states,
                                ctx_states,
                                padding_mask=inter_padding_mask,
                                static_kv=True)
        output = self.feedforward(output)
        return output
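
All of the examples rely on a positional_encodings_like helper that is not shown here. As a point of reference only, here is a minimal sketch of what such a helper might look like, assuming standard sinusoidal encodings and a `(..., seq_len, depth)` layout where the second-to-last dimension is the sequence; the actual implementation in the source repository may differ.

import math

import torch


def positional_encodings_like(x):
    # Sketch only (assumption): the time dimension is x.size(-2) and the
    # model dimension is x.size(-1).
    seq_len, depth = x.size(-2), x.size(-1)
    position = torch.arange(seq_len, dtype=torch.float32, device=x.device)
    # inverse frequencies 1 / 10000^(2i / depth), as in the Transformer paper
    inv_freq = torch.exp(
        torch.arange(0, depth, 2, dtype=torch.float32, device=x.device)
        * (-math.log(10000.0) / depth))
    enc = torch.zeros(seq_len, depth, device=x.device)
    enc[:, 0::2] = torch.sin(position[:, None] * inv_freq)
    enc[:, 1::2] = torch.cos(position[:, None] * inv_freq)[:, :depth // 2]
    # broadcast over any leading (batch) dimensions of x
    return enc.to(x.dtype).expand_as(x)

Given such a helper, `input_states + positional_encodings_like(input_states)` in the snippet above injects absolute position information before self-attention.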
Example no. 2
    def extract_features(self, prev_output_tokens, encoder_out=None, **unused):
        """
        Similar to *forward* but only return features.

        Returns:
            tuple:
                - the decoder's features of shape `(batch, tgt_len, embed_dim)`
                - a dictionary with any model-specific outputs
                    length-predict module output
                    position-search module output
                    position-predict module output
                    decoder module output
        """
        # bridge the encoder output into initial decoder inputs
        inputs_dict = self._bridging(encoder_out=encoder_out,
                                     prev_output_tokens=prev_output_tokens)
        x = inputs_dict['inputs']
        if self.project_in_dim is not None:
            x = self.project_in_dim(x)
        positions = positional_encodings_like(x)

        if positions is not None:
            x += positions
        x = F.dropout(x, p=self.dropout, training=self.training)

        # The transpose to T x B x C, the decoder layers, the final layer norm
        # and the output projection that were previously inlined here are now
        # handled by self._decoding() below.
        x, inputs_dict = self._decoding(
            x,
            encoder_out,
            inputs_dict,
        )
        return x, inputs_dict
Example no. 3
    def forward(self, src_tokens, src_lengths, **unused):
        """
        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (torch.LongTensor): lengths of each source sentence of
                shape `(batch)`

        Returns:
            dict:
                - **encoder_out** (Tensor): the last encoder layer's output of
                  shape `(src_len, batch, embed_dim)`
                - **encoder_padding_mask** (ByteTensor): the positions of
                  padding elements of shape `(batch, src_len)`, or ``None``
                  if the batch contains no padding
                - **encoder_history** (List[Tensor]): the embedded inputs
                  followed by each encoder layer's output, each of shape
                  `(batch, src_len, embed_dim)`
        """
        # embed tokens and positions
        if self.share_embed:
            x = F.embedding(src_tokens, self.embed_tokens.weight * self.embed_scale)
        else:
            x = self.embed_tokens(src_tokens) * self.embed_scale
        x += positional_encodings_like(x)
        encoder_history = [x]
        x = self.dropout(x)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # compute padding mask
        encoder_padding_mask = src_tokens.eq(self.padding_idx)
        if not encoder_padding_mask.any():
            encoder_padding_mask = None

        # encoder layers
        for layer in self.layers:
            x = layer(x, encoder_padding_mask)
            encoder_history.append(x.transpose(0, 1))

        if self.normalize:
            x = self.layer_norm(x)

        return {
            'encoder_history': encoder_history,  # List<B X T X C>
            'encoder_out': x,  # T x B x C
            'encoder_padding_mask': encoder_padding_mask,  # B x T
        }
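
Two padding-mask conventions appear across the examples: Examples no. 1 and 4 invert a validity mask with `(1 - mask).byte()` (so the result is 1 on padding positions), while Example no. 3 derives the mask directly from the tokens with `src_tokens.eq(self.padding_idx)` (True on padding). A small self-contained check that the two conventions agree, using toy tensors and a hypothetical padding index:

import torch

padding_idx = 1  # hypothetical padding index for this toy check
src_tokens = torch.tensor([[5, 7, 9, 1, 1],
                           [4, 6, 1, 1, 1]])  # two sentences, 1 = <pad>

# Example no. 3 style: True where the token is padding
encoder_padding_mask = src_tokens.eq(padding_idx)

# Examples no. 1/4 style: start from a validity mask (1 on real tokens)
# and invert it
mask_input = src_tokens.ne(padding_idx).float()
self_padding_mask = (1 - mask_input).byte()

assert torch.equal(encoder_padding_mask, self_padding_mask.bool())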
Example no. 4
    def forward(self, input_states, ctx_states, mask_input, mask_ctx,
                **kwargs):
        """Return output of shape `(seq_len, batch_size, depth)`."""
        output = self.encode(input_states, mask_input)
        self_padding_mask = (1 - mask_input).byte()
        if self.pos_slf_attn is not None:
            # positional self-attention over the encoded states
            pos_encoding = positional_encodings_like(output)
            output = self.pos_slf_attn(x=pos_encoding,
                                       value=output,
                                       padding_mask=self_padding_mask)
        # attend over the context states
        key_padding_mask = (1 - mask_ctx).byte()
        output = self.cond_attn(output, ctx_states, ctx_states,
                                key_padding_mask)
        output = self.feedforward(output)
        return output
Example no. 5
    def forward(self,
                x,
                encoder_out=None,
                encoder_padding_mask=None,
                self_attn_mask=None,
                self_attn_padding_mask=None,
                pos_key=None,
                pos_val=None,
                **kwargs):
        pos_key, pos_val = self._get_pos_repr(x, pos_key, pos_val)
        x, attn = self.self_attn_block(x=x,
                                       padding_mask=self_attn_padding_mask,
                                       attn_mask=self_attn_mask,
                                       ret_attn=True,
                                       need_weights=False,
                                       pos_key=pos_key,
                                       pos_val=pos_val)
        if self.pos_self_attn_block is not None:
            pos_encoding = positional_encodings_like(x)
            x = self.pos_self_attn_block(
                x=pos_encoding,
                value=x,
                padding_mask=self_attn_padding_mask,
                attn_mask=self_attn_mask,
                need_weights=False,
            )

        if self.encoder_attn_block is not None:
            x, attn = self.encoder_attn_block(
                query=x,
                key=encoder_out,
                value=encoder_out,
                padding_mask=encoder_padding_mask,
                ret_attn=True,
                static_kv=True,
                need_weights=(not self.training and self.need_attn),
            )
        x = self.ffn_block(x)
        return x, attn
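
Examples no. 1, 4, 5 and 7 all use the same positional self-attention pattern: the queries and keys are the position encodings, while the values are the current hidden states. The blocks themselves are custom modules not shown here; what follows is only a minimal sketch of the idea on top of torch.nn.MultiheadAttention (post-norm, residual over the values); the real blocks will differ in details such as pre-norm, dropout placement, or incremental state.

import torch
import torch.nn as nn


class PosSelfAttentionSketch(nn.Module):
    """Hypothetical sketch of the pos_slf_attn / pos_self_attn_block idea."""

    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads,
                                          dropout=dropout)
        self.layer_norm = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, value, padding_mask=None, attn_mask=None,
                need_weights=False):
        # x: position encodings (T, B, C); value: hidden states (T, B, C)
        residual = value
        out, _ = self.attn(query=x, key=x, value=value,
                           key_padding_mask=padding_mask,
                           attn_mask=attn_mask,
                           need_weights=need_weights)
        return self.layer_norm(residual + self.dropout(out))

With this sketch, `block(x=pos_encoding, value=x, padding_mask=...)` mirrors the calls made in the snippets above.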
Example no. 6
    def wrap_position(self, pos_query, candidate):
        """
             return the position for matching
        :param pos_query: batch, query_len, state_dim
        :param candidate: batch, target_len, state_dim
        :return:
        """
        if self.position_type == 0:
            return 0, 0
        # batch_size, target_len, state_dim
        position_candidate = positional_encodings_like(candidate)
        if self.position_type == 1:
            return 0, position_candidate
        if self.normalize:
            # batch, query_len, target_len
            position_logits = F.cosine_similarity(pos_query[:, :, None, :],
                                                  position_candidate[:, None, :, :],
                                                  dim=-1)
        else:
            # batch, query_len, target_len
            position_logits = torch.matmul(pos_query, position_candidate.contiguous().transpose(1, 2))

        # batch_size, query_len, state_dim
        position_query = torch.matmul(position_logits, position_candidate)
        return position_query, position_candidate
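
For reference, a shape walk-through of the matmul branch above, with toy dimensions and a random stand-in for `positional_encodings_like(candidate)`:

import torch

batch, query_len, target_len, state_dim = 2, 3, 5, 8
pos_query = torch.randn(batch, query_len, state_dim)
# stand-in for position_candidate = positional_encodings_like(candidate)
position_candidate = torch.randn(batch, target_len, state_dim)

# (B, Q, D) x (B, D, T) -> (B, Q, T): each query position scored against
# every candidate position
position_logits = torch.matmul(pos_query,
                               position_candidate.transpose(1, 2))
assert position_logits.shape == (batch, query_len, target_len)

# (B, Q, T) x (B, T, D) -> (B, Q, D): soft mixture of candidate encodings
position_query = torch.matmul(position_logits, position_candidate)
assert position_query.shape == (batch, query_len, state_dim)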
Example no. 7
    def forward(
        self,
        x,
        encoder_out=None,
        encoder_padding_mask=None,
        incremental_state=None,
        prev_self_attn_state=None,
        prev_attn_state=None,
        self_attn_mask=None,
        self_attn_padding_mask=None,
    ):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_padding_mask (ByteTensor): binary ByteTensor of shape
                `(batch, src_len)` where padding elements are indicated by ``1``.

        Returns:
            encoded output of shape `(seq_len, batch, embed_dim)` and the
            attention weights (or ``None``)
        """
        if prev_self_attn_state is not None:
            if incremental_state is None:
                incremental_state = {}
            prev_key, prev_value = prev_self_attn_state
            saved_state = {"prev_key": prev_key, "prev_value": prev_value}
            self.self_attn.set_input_buffer(incremental_state, saved_state)
        residual = x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True)
        x, attn = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            need_weights=False,
            attn_mask=self_attn_mask,
        )
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True)

        if self.pos_self_attn_block is not None:
            pos_encoding = positional_encodings_like(x)
            x = self.pos_self_attn_block(
                x=pos_encoding,
                value=x,
                padding_mask=self_attn_padding_mask,
                attn_mask=self_attn_mask,
                need_weights=False,
            )

        if self.encoder_attn is not None:
            if prev_attn_state is not None:
                if incremental_state is None:
                    incremental_state = {}
                prev_key, prev_value = prev_attn_state
                saved_state = {"prev_key": prev_key, "prev_value": prev_value}
                self.encoder_attn.set_input_buffer(incremental_state,
                                                   saved_state)

            residual = x
            x = self.maybe_layer_norm(self.encoder_attn_layer_norm,
                                      x,
                                      before=True)
            x, attn = self.encoder_attn(
                query=x,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                need_weights=(not self.training and self.need_attn),
            )
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = residual + x
            x = self.maybe_layer_norm(self.encoder_attn_layer_norm,
                                      x,
                                      after=True)

        residual = x
        x = self.maybe_layer_norm(self.final_layer_norm, x, before=True)
        x = self.activation_fn(self.fc1(x))
        x = F.dropout(x, p=self.activation_dropout, training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)
        if self.onnx_trace and incremental_state is not None:
            saved_state = self.self_attn.get_input_buffer(incremental_state)
            self_attn_state = saved_state["prev_key"], saved_state["prev_value"]
            return x, attn, self_attn_state
        return x, attn
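
Example no. 7 switches between pre- and post-layer-norm through `maybe_layer_norm`. The helper is not shown here; in fairseq-style transformer layers it usually applies the norm either before or after the sublayer depending on a `normalize_before` flag. A hedged, standalone sketch of that convention (the actual method takes `self` and reads `self.normalize_before`):

import torch.nn as nn


def maybe_layer_norm(layer_norm: nn.LayerNorm, x, normalize_before: bool,
                     before: bool = False, after: bool = False):
    # Sketch of the fairseq convention: with pre-norm (normalize_before=True)
    # the norm fires on the `before=True` call, with post-norm on the
    # `after=True` call.
    assert before ^ after, "exactly one of before/after must be set"
    if after ^ normalize_before:
        return layer_norm(x)
    return x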