Example #1
    def call(
        self,
        tgt: np.ndarray,
        memory: np.ndarray,
        tgt_mask: tp.Optional[np.ndarray] = None,
        memory_mask: tp.Optional[np.ndarray] = None,
        # tgt_key_padding_mask: tp.Optional[np.ndarray] = None,
        # memory_key_padding_mask: tp.Optional[np.ndarray] = None,
    ) -> np.ndarray:
        r"""Pass the inputs (and mask) through the decoder layer.

        Args:
            tgt: the sequence to the decoder layer (required).
            memory: the sequence from the last layer of the encoder (required).
            tgt_mask: the mask for the tgt sequence (optional).
            memory_mask: the mask for the memory sequence (optional).
            tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
            memory_key_padding_mask: the mask for the memory keys per batch (optional).

        Shape:
            see the docs in Transformer class.
        """
        # Self-attention block over the target sequence.
        tgt2 = MultiHeadAttention(self.head_size,
                                  self.num_heads,
                                  dropout=self.dropout)(tgt, mask=tgt_mask)
        tgt = tgt + Dropout(self.dropout)(tgt2)
        tgt = LayerNormalization()(tgt)
        # Cross-attention block attending to the encoder memory.
        tgt2 = MultiHeadAttention(self.head_size,
                                  self.num_heads,
                                  dropout=self.dropout)(
                                      tgt,
                                      memory,
                                      mask=memory_mask,
                                  )
        tgt = tgt + Dropout(self.dropout)(tgt2)
        tgt = LayerNormalization()(tgt)
        # Position-wise feed-forward block; the residual assumes output_size
        # matches tgt's feature dimension.
        tgt = tgt + sequential(
            Linear(self.output_size),
            self.activation,
            Dropout(self.dropout),
            Linear(self.output_size),
            Dropout(self.dropout),
        )(tgt)
        tgt = LayerNormalization()(tgt)
        return tgt
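
A minimal usage sketch for this decoder layer. It is hypothetical: the `TransformerDecoderLayer` class name and its constructor signature are inferred from Example #3 and the attributes used above, and a stand-in ReLU is used for `activation`:

    import numpy as np

    layer = TransformerDecoderLayer(          # assumed constructor, see Example #3
        head_size=64,
        num_heads=8,
        output_size=64,                       # should match the feature dimension
        dropout=0.1,
        activation=lambda x: np.maximum(x, 0),  # stand-in ReLU
    )
    tgt = np.random.uniform(size=(8, 10, 64))     # (batch, tgt_len, features)
    memory = np.random.uniform(size=(8, 12, 64))  # (batch, src_len, features)
    out = layer.call(tgt, memory)                 # -> (8, 10, 64)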
Example #2
    def call(
        self,
        src: np.ndarray,
        mask: tp.Optional[np.ndarray] = None,
        # src_key_padding_mask: tp.Optional[np.ndarray] = None,
    ) -> np.ndarray:
        r"""Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required).
            mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).

        Shape:
            see the docs in Transformer class.
        """
        # Default the feed-forward width to the input feature size.
        output_size: int = (
            self.output_size if self.output_size is not None else src.shape[-1]
        )

        # Self-attention block over the source sequence.
        src2 = MultiHeadAttention(
            self.head_size,
            self.num_heads,
            dropout=self.dropout,
        )(src, mask=mask)
        src = src + Dropout(self.dropout)(src2)
        src = LayerNormalization()(src)
        # Position-wise feed-forward block.
        src2 = sequential(
            Linear(output_size),
            self.activation,
            Dropout(self.dropout),
            Linear(output_size),
        )(src)
        src = src + Dropout(self.dropout)(src2)
        src = LayerNormalization()(src)
        return src
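
The encoder layer can be sketched the same way; with `output_size=None` the feed-forward width falls back to `src.shape[-1]`, so the output keeps the input shape. The class name and constructor order are again assumptions inferred from Example #3:

    import numpy as np

    layer = TransformerEncoderLayer(          # assumed constructor, see Example #3
        head_size=64,
        num_heads=8,
        output_size=None,                     # falls back to src.shape[-1]
        dropout=0.1,
        activation=lambda x: np.maximum(x, 0),  # stand-in ReLU
    )
    src = np.random.uniform(size=(8, 12, 64))  # (batch, src_len, features)
    out = layer.call(src)                      # -> (8, 12, 64)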
Example #3
    def call(
        self,
        src: np.ndarray,
        tgt: np.ndarray,
        src_mask: tp.Optional[np.ndarray] = None,
        tgt_mask: tp.Optional[np.ndarray] = None,
        memory_mask: tp.Optional[np.ndarray] = None,
        # src_key_padding_mask: tp.Optional[np.ndarray] = None,
        # tgt_key_padding_mask: tp.Optional[np.ndarray] = None,
        # memory_key_padding_mask: tp.Optional[np.ndarray] = None,
    ) -> np.ndarray:
        r"""Take in and process masked source/target sequences.

        Args:
            src: the sequence to the encoder (required).
            tgt: the sequence to the decoder (required).
            src_mask: the additive mask for the src sequence (optional).
            tgt_mask: the additive mask for the tgt sequence (optional).
            memory_mask: the additive mask for the encoder output (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).
            tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
            memory_key_padding_mask: the mask for the memory keys per batch (optional).

        Shape:
            - src: :math:`(N, S, E)`.
            - tgt: :math:`(N, T, E)`.
            - src_mask: :math:`(S, S)`.
            - tgt_mask: :math:`(T, T)`.
            - memory_mask: :math:`(T, S)`.
            - src_key_padding_mask: :math:`(N, S)`.
            - tgt_key_padding_mask: :math:`(N, T)`.
            - memory_key_padding_mask: :math:`(N, S)`.

            Note: [src/tgt/memory]_mask ensures that position i is only allowed to attend
            the unmasked positions. If a boolean array is provided, positions with ``True``
            are not allowed to attend while ``False`` values will be unchanged. If a float
            array is provided, it will be added to the attention weight.
            [src/tgt/memory]_key_padding_mask marks elements in the key that should be
            ignored by the attention. If a boolean array is provided, positions with the
            value of ``True`` will be ignored while positions with the value of ``False``
            will be unchanged.

            - output: :math:`(N, T, E)`.

            Note: Due to the multi-head attention architecture in the transformer model,
            the output sequence length of a transformer is the same as the input sequence
            (i.e. target) length of the decoder,

            where S is the source sequence length, T is the target sequence length, N is
            the batch size, and E is the feature dimension.

        Examples:
            >>> # output = transformer_model(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask)
        """

        # The batch dimension is axis 0 for both src and tgt.
        if src.shape[0] != tgt.shape[0]:
            raise RuntimeError("the batch number of src and tgt must be equal")

        # if src.shape[2] != self.head_size or tgt.shape[2] != self.head_size:
        #     raise RuntimeError(
        #         "the feature number of src and tgt must be equal to head_size"
        #     )

        # Build the encoder stack; layer and norm factories are passed so each
        # call constructs fresh modules.
        if self.custom_encoder is not None:
            encoder = self.custom_encoder()
        else:
            encoder = TransformerEncoder(
                lambda: TransformerEncoderLayer(
                    self.head_size,
                    self.num_heads,
                    self.output_size,
                    self.dropout,
                    self.activation,
                ),
                self.num_encoder_layers,
                lambda: LayerNormalization(),
            )

        # Build the decoder stack.
        if self.custom_decoder is not None:
            decoder = self.custom_decoder()
        else:
            decoder = TransformerDecoder(
                lambda: TransformerDecoderLayer(
                    self.head_size,
                    self.num_heads,
                    self.output_size,
                    self.dropout,
                    self.activation,
                ),
                self.num_decoder_layers,
                lambda: LayerNormalization(),
            )

        # Encode the source sequence into memory.
        memory = encoder(
            src,
            mask=src_mask,
            # src_key_padding_mask=src_key_padding_mask
        )
        # Decode the target sequence, attending to the encoder memory.
        output = decoder(
            tgt,
            memory,
            tgt_mask=tgt_mask,
            memory_mask=memory_mask,
            # tgt_key_padding_mask=tgt_key_padding_mask,
            # memory_key_padding_mask=memory_key_padding_mask,
        )

        return output
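
A hedged end-to-end sketch: build a causal target mask following the boolean convention documented above (``True`` marks positions that may not be attended) and run the full model. The `Transformer` constructor arguments mirror the attributes used in this example but are assumptions, and the actual mask convention of the underlying MultiHeadAttention should be verified before relying on this:

    import numpy as np

    def causal_mask(size: int) -> np.ndarray:
        # True above the diagonal: position i may not attend to j > i
        # (boolean convention from the docstring above; verify against
        # the MultiHeadAttention implementation in use).
        return ~np.tril(np.ones((size, size), dtype=bool))

    model = Transformer(                      # assumed constructor
        head_size=64,
        num_heads=8,
        output_size=64,
        num_encoder_layers=2,
        num_decoder_layers=2,
        dropout=0.1,
        activation=lambda x: np.maximum(x, 0),  # stand-in ReLU
        custom_encoder=None,
        custom_decoder=None,
    )
    src = np.random.uniform(size=(8, 12, 64))  # (N, S, E)
    tgt = np.random.uniform(size=(8, 10, 64))  # (N, T, E)
    output = model.call(src, tgt, tgt_mask=causal_mask(10))  # -> (N, T, E)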