Example No. 1
    def cross_attention_sublayer(self, queries: tf.Tensor) -> tf.Tensor:
        assert self.cross_attention_sublayer is not None
        encoder_att_states = get_attention_states(
            self.input_for_cross_attention)
        encoder_att_mask = get_attention_mask(self.input_for_cross_attention)

        # Layer normalization
        normalized_queries = layer_norm(queries)

        encoder_context, _ = attention(
            queries=normalized_queries,
            keys=encoder_att_states,
            values=encoder_att_states,
            keys_mask=encoder_att_mask,
            num_heads=self.n_cross_att_heads,
            dropout_callback=lambda x: dropout(
                x, self.attention_dropout_keep_prob, self.train_mode),
            use_bias=self.use_att_transform_bias)

        # Apply dropout
        encoder_context = dropout(encoder_context, self.dropout_keep_prob,
                                  self.train_mode)

        # Add residual connections
        return encoder_context + queries
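All of the sublayer methods in these examples share the same pre-norm wiring: layer-normalize the input, run the attention, apply dropout, and add a residual connection back to the raw input. Below is a minimal NumPy sketch of that pattern only, not code from the project; the names `pre_norm_sublayer` and `sublayer_fn` are hypothetical stand-ins for the library's helpers.

import numpy as np

def pre_norm_sublayer(x, sublayer_fn, keep_prob=0.9, train=True):
    """Sketch of the shared pattern: norm -> sublayer -> dropout -> residual."""
    # Layer normalization over the feature (last) axis
    normalized = (x - x.mean(-1, keepdims=True)) / (x.std(-1, keepdims=True) + 1e-6)

    # The sublayer itself, e.g. multi-head attention over encoder states
    context = sublayer_fn(normalized)

    # Inverted dropout on the sublayer output, active only in training
    if train:
        keep = np.random.binomial(1, keep_prob, size=context.shape)
        context = context * keep / keep_prob

    # Residual connection back to the raw (un-normalized) input
    return context + x

# e.g. pre_norm_sublayer(np.random.randn(2, 5, 8), lambda q: 0.5 * q)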
Example No. 2
    def self_attention_sublayer(
            self, prev_layer: TransformerLayer) -> tf.Tensor:
        """Create the decoder self-attention sublayer with output mask."""

        # Layer normalization
        normalized_states = layer_norm(prev_layer.temporal_states)

        # Run self-attention
        # TODO handle attention histories
        self_context, _ = attention(
            queries=normalized_states,
            keys=normalized_states,
            values=normalized_states,
            keys_mask=prev_layer.temporal_mask,
            num_heads=self.n_heads_self,
            masked=True,
            dropout_callback=lambda x: dropout(
                x, self.self_att_dropout_keep_prob, self.train_mode),
            use_bias=self.use_att_transform_bias)

        # Apply dropout
        self_context = dropout(
            self_context, self.dropout_keep_prob, self.train_mode)

        # Add residual connections
        return self_context + prev_layer.temporal_states
Example No. 3
    def self_attention_sublayer(self,
                                prev_layer: TransformerLayer) -> tf.Tensor:
        """Create the decoder self-attention sublayer with output mask."""

        # Layer normalization
        normalized_states = layer_norm(prev_layer.temporal_states)

        # Run self-attention
        # TODO handle attention histories
        self_context, _ = attention(
            queries=normalized_states,
            keys=normalized_states,
            values=normalized_states,
            keys_mask=prev_layer.temporal_mask,
            num_heads=self.n_heads_self,
            masked=True,
            dropout_callback=lambda x: dropout(
                x, self.attention_dropout_keep_prob, self.train_mode),
            use_bias=self.use_att_transform_bias)

        # Apply dropout
        self_context = dropout(self_context, self.dropout_keep_prob,
                               self.train_mode)

        # Add residual connections
        return self_context + prev_layer.temporal_states
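The `masked=True` argument in the self-attention calls above asks the library's `attention` function to apply a look-ahead (causal) mask, so each decoding position attends only to itself and earlier positions. A tiny illustrative NumPy sketch of such a mask follows; the function name `causal_mask` is hypothetical and not part of the project.

import numpy as np

def causal_mask(length: int) -> np.ndarray:
    # Lower-triangular [length, length] matrix: row i has ones for
    # positions 0..i and zeros for all future positions.
    return np.tril(np.ones((length, length), dtype=np.float32))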
Example No. 4
    def encoder_attention_sublayer(self, queries: tf.Tensor) -> tf.Tensor:
        """Create the encoder-decoder attention sublayer."""

        encoder_att_states = get_attention_states(self.encoder)
        encoder_att_mask = get_attention_mask(self.encoder)

        # Layer normalization
        normalized_queries = layer_norm(queries)

        # Attend to the encoder
        # TODO handle histories
        encoder_context, _ = attention(
            queries=normalized_queries,
            keys=encoder_att_states,
            values=encoder_att_states,
            keys_mask=encoder_att_mask,
            num_heads=self.n_heads_enc,
            dropout_callback=lambda x: dropout(
                x, self.attention_dropout_keep_prob, self.train_mode),
            use_bias=self.use_att_transform_bias)

        # Apply dropout
        encoder_context = dropout(encoder_context, self.dropout_keep_prob,
                                  self.train_mode)

        # Add residual connections
        return encoder_context + queries
Example No. 5
    def cross_attention_sublayer(self, queries: tf.Tensor) -> tf.Tensor:
        assert self.cross_attention_sublayer is not None
        assert self.n_cross_att_heads is not None
        assert self.input_for_cross_attention is not None

        encoder_att_states = get_attention_states(
            self.input_for_cross_attention)
        encoder_att_mask = get_attention_mask(self.input_for_cross_attention)

        # Layer normalization
        normalized_queries = layer_norm(queries)

        encoder_context, _ = attention(
            queries=normalized_queries,
            keys=encoder_att_states,
            values=encoder_att_states,
            keys_mask=encoder_att_mask,
            num_heads=self.n_cross_att_heads,
            dropout_callback=lambda x: dropout(
                x, self.attention_dropout_keep_prob, self.train_mode),
            use_bias=self.use_att_transform_bias)

        # Apply dropout
        encoder_context = dropout(
            encoder_context, self.dropout_keep_prob, self.train_mode)

        # Add residual connections
        return encoder_context + queries
Example No. 6
def single(
        queries: tf.Tensor,
        states: tf.Tensor,
        mask: tf.Tensor,
        n_heads: int,
        attention_dropout_callback: Callable[[tf.Tensor], tf.Tensor],
        dropout_callback: Callable[[tf.Tensor], tf.Tensor],
        normalize: bool = True,
        use_dropout: bool = True,
        residual: bool = True,
        use_att_transform_bias: bool = False):
    """Run attention on a single encoder.

    Arguments:
        queries: The input for the attention.
        states: The encoder states (keys & values).
        mask: The temporal mask of the encoder.
        n_heads: Number of attention heads to use.
        attention_dropout_callback: Dropout function to apply in attention.
        dropout_callback: Dropout function to apply on the attention output.
        normalize: If True, run layer normalization on the queries.
        use_dropout: If True, perform dropout on the attention output.
        residual: If True, sum the context vector with the input queries.
        use_att_transform_bias: If True, enable bias in the attention head
            projections (for all queries, keys and values).

    Returns:
        A Tensor that contains the context vector.
    """

    # Layer normalization
    normalized_queries = layer_norm(queries) if normalize else queries

    # Attend to the encoder
    # TODO handle attention histories
    encoder_context, _ = attention(
        queries=normalized_queries,
        keys=states,
        values=states,
        keys_mask=mask,
        num_heads=n_heads,
        dropout_callback=attention_dropout_callback,
        use_bias=use_att_transform_bias)

    # Apply dropout
    if use_dropout:
        encoder_context = dropout_callback(encoder_context)

    # Add residual connections
    if residual:
        encoder_context += queries

    return encoder_context
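A hypothetical call of `single` for encoder-decoder attention, assuming the `single` and `dropout` helpers shown in these snippets and a TF 1.x graph environment; the tensor shapes, the keep-probability of 0.9, and the head count of 8 are invented for illustration.

import tensorflow as tf

decoder_states = tf.zeros([16, 20, 512])   # [batch, decoder length, model dim]
encoder_states = tf.zeros([16, 30, 512])   # [batch, encoder length, model dim]
encoder_mask = tf.ones([16, 30])           # [batch, encoder length]
train_mode = tf.constant(False)

context = single(
    queries=decoder_states,
    states=encoder_states,
    mask=encoder_mask,
    n_heads=8,
    attention_dropout_callback=lambda x: dropout(x, 0.9, train_mode),
    dropout_callback=lambda x: dropout(x, 0.9, train_mode))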
Example No. 7
def single(queries: tf.Tensor,
           states: tf.Tensor,
           mask: tf.Tensor,
           n_heads: int,
           attention_dropout_callback: Callable[[tf.Tensor], tf.Tensor],
           dropout_callback: Callable[[tf.Tensor], tf.Tensor],
           normalize: bool = True,
           use_dropout: bool = True,
           residual: bool = True,
           use_att_transform_bias: bool = False):
    """Run attention on a single encoder.

    Arguments:
        queries: The input for the attention.
        states: The encoder states (keys & values).
        mask: The temporal mask of the encoder.
        n_heads: Number of attention heads to use.
        attention_dropout_callback: Dropout function to apply in attention.
        dropout_callback: Dropout function to apply on the attention output.
        normalize: If True, run layer normalization on the queries.
        use_dropout: If True, perform dropout on the attention output.
        residual: If True, sum the context vector with the input queries.
        use_att_transform_bias: If True, enable bias in the attention head
            projections (for all queries, keys and values).

    Returns:
        A Tensor that contains the context vector.
    """

    # Layer normalization
    normalized_queries = layer_norm(queries) if normalize else queries

    # Attend to the encoder
    # TODO handle attention histories
    encoder_context, _ = attention(queries=normalized_queries,
                                   keys=states,
                                   values=states,
                                   keys_mask=mask,
                                   num_heads=n_heads,
                                   dropout_callback=attention_dropout_callback,
                                   use_bias=use_att_transform_bias)

    # Apply dropout
    if use_dropout:
        encoder_context = dropout_callback(encoder_context)

    # Add residual connections
    if residual:
        encoder_context += queries

    return encoder_context
Example No. 8
    def self_attention(self, level: int,
                       prev_layer: TransformerLayer) -> tf.Tensor:

        with tf.variable_scope("self_attention_{}".format(level)):
            self_context, _ = attention(
                queries=prev_layer.temporal_states,
                keys=prev_layer.temporal_states,
                values=prev_layer.temporal_states,
                keys_mask=prev_layer.temporal_mask,
                num_heads=self.n_heads,
                dropout_callback=lambda x: dropout(
                    x, self.attention_dropout_keep_prob, self.train_mode))

            return dropout(self_context, self.dropout_keep_prob,
                           self.train_mode)
Example No. 9
    def masked_self_attention(
            self, level: int, prev_layer: TransformerLayer) -> tf.Tensor:

        with tf.variable_scope("dec_self_att_level_{}".format(level),
                               reuse=tf.AUTO_REUSE):
            # TODO handle histories
            self_context, _ = attention(
                queries=prev_layer.temporal_states,
                keys=prev_layer.temporal_states,
                values=prev_layer.temporal_states,
                keys_mask=prev_layer.temporal_mask,
                num_heads=self.n_heads_self,
                masked=True,
                dropout_callback=lambda x: dropout(
                    x, self.attention_dropout_keep_prob, self.train_mode))

            return dropout(
                self_context, self.dropout_keep_prob, self.train_mode)
Example No. 10
    def encoder_attention(self, level: int, queries: tf.Tensor) -> tf.Tensor:

        with tf.variable_scope("dec_inter_att_level_{}".format(level),
                               reuse=tf.AUTO_REUSE):
            encoder_att_states = get_attention_states(self.encoder)
            encoder_att_mask = get_attention_mask(self.encoder)

            # TODO handle histories
            encoder_context, _ = attention(
                queries=queries,
                keys=encoder_att_states,
                values=encoder_att_states,
                keys_mask=encoder_att_mask,
                num_heads=self.n_heads_enc,
                dropout_callback=lambda x: dropout(
                    x, self.attention_dropout_keep_prob, self.train_mode))

            return dropout(
                encoder_context, self.dropout_keep_prob, self.train_mode)
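A hypothetical sketch of how the two sublayers above are commonly chained into one decoder layer; the residual connections are added by the caller here because `masked_self_attention` and `encoder_attention` return only the dropped-out context vectors. The method name `decoder_layer` and the omission of the feed-forward sublayer are assumptions, not code from the project.

    def decoder_layer(self, level: int,
                      prev_layer: TransformerLayer) -> tf.Tensor:
        self_context = self.masked_self_attention(level, prev_layer)
        queries = self_context + prev_layer.temporal_states  # residual

        encoder_context = self.encoder_attention(level, queries)
        return encoder_context + queries  # residual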