def layer(self, level: int) -> TransformerLayer:
    # Recursive implementation. Outputs of the zeroth layer
    # are normalized inputs.
    if level == 0:
        return TransformerLayer(self.encoder_inputs, self.temporal_mask)

    # Compute the outputs of the previous layer
    prev_layer = self.layer(level - 1)

    with tf.variable_scope("layer_{}".format(level - 1)):
        with tf.variable_scope("self_attention"):
            self_context = self.self_attention_sublayer(prev_layer)

        if self.input_for_cross_attention is not None:
            with tf.variable_scope("cross_attention"):
                self_context = self.cross_attention_sublayer(self_context)

        with tf.variable_scope("feedforward"):
            output_states = self.feedforward_sublayer(self_context)

    # Layer normalization on the encoder outputs
    if self.depth == level:
        output_states = layer_norm(output_states)

    return TransformerLayer(states=output_states, mask=self.temporal_mask)
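The recursive `layer` method above is typically driven from the top of the stack. Below is a minimal sketch, not part of the module itself, of what such an entry point could look like; the property name `output_states` is hypothetical, while `self.depth` and the `temporal_states` field of `TransformerLayer` are taken from the listing.

@property
def output_states(self) -> tf.Tensor:
    # Hypothetical entry point: requesting layer `self.depth` unrolls the
    # recursion down to the zeroth layer, and the `self.depth == level`
    # check above applies the closing layer normalization to the topmost
    # output before it is returned.
    return self.layer(self.depth).temporal_states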
def encoder_attention_sublayer(self, queries: tf.Tensor) -> tf.Tensor:
    """Create the encoder-decoder attention sublayer."""
    encoder_att_states = get_attention_states(self.encoder)
    encoder_att_mask = get_attention_mask(self.encoder)

    # Layer normalization
    normalized_queries = layer_norm(queries)

    # Attend to the encoder
    # TODO handle histories
    encoder_context, _ = attention(
        queries=normalized_queries,
        keys=encoder_att_states,
        values=encoder_att_states,
        keys_mask=encoder_att_mask,
        num_heads=self.n_heads_enc,
        dropout_callback=lambda x: dropout(
            x, self.attention_dropout_keep_prob, self.train_mode),
        use_bias=self.use_att_transform_bias)

    # Apply dropout
    encoder_context = dropout(encoder_context, self.dropout_keep_prob,
                              self.train_mode)

    # Add residual connections
    return encoder_context + queries
def layer(self, level: int, inputs: tf.Tensor,
          mask: tf.Tensor) -> TransformerLayer:
    # Recursive implementation. Outputs of the zeroth layer
    # are the inputs.
    if level == 0:
        return TransformerLayer(inputs, mask)

    # Compute the outputs of the previous layer
    prev_layer = self.layer(level - 1, inputs, mask)

    with tf.variable_scope("layer_{}".format(level - 1)):
        with tf.variable_scope("self_attention"):
            self_context = self.self_attention_sublayer(prev_layer)

        with tf.variable_scope("encdec_attention"):
            encoder_context = self.encoder_attention_sublayer(self_context)

        with tf.variable_scope("feedforward"):
            output_states = self.feedforward_sublayer(encoder_context)

    # Layer normalization on the decoder output
    if self.depth == level:
        output_states = layer_norm(output_states)

    return TransformerLayer(states=output_states, mask=mask)
def cross_attention_sublayer(self, queries: tf.Tensor) -> tf.Tensor:
    assert self.cross_attention_sublayer is not None

    encoder_att_states = get_attention_states(
        self.input_for_cross_attention)
    encoder_att_mask = get_attention_mask(self.input_for_cross_attention)

    # Layer normalization
    normalized_queries = layer_norm(queries)

    encoder_context, _ = attention(
        queries=normalized_queries,
        keys=encoder_att_states,
        values=encoder_att_states,
        keys_mask=encoder_att_mask,
        num_heads=self.n_cross_att_heads,
        dropout_callback=lambda x: dropout(
            x, self.attention_dropout_keep_prob, self.train_mode),
        use_bias=self.use_att_transform_bias)

    # Apply dropout
    encoder_context = dropout(encoder_context, self.dropout_keep_prob,
                              self.train_mode)

    # Add residual connections
    return encoder_context + queries
def cross_attention_sublayer(self, queries: tf.Tensor) -> tf.Tensor:
    assert self.cross_attention_sublayer is not None
    assert self.n_cross_att_heads is not None
    assert self.input_for_cross_attention is not None

    encoder_att_states = get_attention_states(
        self.input_for_cross_attention)
    encoder_att_mask = get_attention_mask(self.input_for_cross_attention)

    # Layer normalization
    normalized_queries = layer_norm(queries)

    encoder_context, _ = attention(
        queries=normalized_queries,
        keys=encoder_att_states,
        values=encoder_att_states,
        keys_mask=encoder_att_mask,
        num_heads=self.n_cross_att_heads,
        dropout_callback=lambda x: dropout(
            x, self.attention_dropout_keep_prob, self.train_mode),
        use_bias=self.use_att_transform_bias)

    # Apply dropout
    encoder_context = dropout(
        encoder_context, self.dropout_keep_prob, self.train_mode)

    # Add residual connections
    return encoder_context + queries
def feedforward_sublayer(self, layer_input: tf.Tensor) -> tf.Tensor:
    """Create the feed-forward network sublayer."""
    # Layer normalization
    normalized_input = layer_norm(layer_input)

    # Feed-forward network hidden layer + ReLU
    ff_hidden = tf.layers.dense(normalized_input, self.ff_hidden_size,
                                activation=tf.nn.relu, name="hidden_state")

    # Apply dropout on hidden layer activations
    ff_hidden = dropout(ff_hidden, self.dropout_keep_prob, self.train_mode)

    # Feed-forward output projection
    ff_output = tf.layers.dense(ff_hidden, self.model_dimension,
                                name="output")

    # Apply dropout on feed-forward output projection
    ff_output = dropout(ff_output, self.dropout_keep_prob, self.train_mode)

    # Add residual connections
    return ff_output + layer_input
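Shape bookkeeping is what makes the final residual addition valid: the hidden projection may be wider, but the output projection must map back to the model dimension. A small standalone sketch, with assumed sizes (512 model dimension, 2048 hidden units) that are not part of the module:

# Assumed sizes for illustration only.
layer_input = tf.zeros([8, 20, 512])   # [batch, time, model_dimension]
ff_hidden = tf.layers.dense(layer_norm(layer_input), 2048,
                            activation=tf.nn.relu)
ff_output = tf.layers.dense(ff_hidden, 512)
# The output projection restores the input width, so the residual
# addition is well-defined.
assert ff_output.shape.as_list() == layer_input.shape.as_list()
residual_out = ff_output + layer_input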
def self_attention_sublayer(self, prev_layer: TransformerLayer) -> tf.Tensor:
    """Create the decoder self-attention sublayer with output mask."""
    # Layer normalization
    normalized_states = layer_norm(prev_layer.temporal_states)

    # Run self-attention
    # TODO handle attention histories
    self_context, _ = attention(
        queries=normalized_states,
        keys=normalized_states,
        values=normalized_states,
        keys_mask=prev_layer.temporal_mask,
        num_heads=self.n_heads_self,
        masked=True,
        dropout_callback=lambda x: dropout(
            x, self.attention_dropout_keep_prob, self.train_mode),
        use_bias=self.use_att_transform_bias)

    # Apply dropout
    self_context = dropout(self_context, self.dropout_keep_prob,
                           self.train_mode)

    # Add residual connections
    return self_context + prev_layer.temporal_states
def self_attention_sublayer(
        self, prev_layer: TransformerLayer) -> tf.Tensor:
    """Create the decoder self-attention sublayer with output mask."""
    # Layer normalization
    normalized_states = layer_norm(prev_layer.temporal_states)

    # Run self-attention
    # TODO handle attention histories
    self_context, _ = attention(
        queries=normalized_states,
        keys=normalized_states,
        values=normalized_states,
        keys_mask=prev_layer.temporal_mask,
        num_heads=self.n_heads_self,
        masked=True,
        dropout_callback=lambda x: dropout(
            x, self.self_att_dropout_keep_prob, self.train_mode),
        use_bias=self.use_att_transform_bias)

    # Apply dropout
    self_context = dropout(
        self_context, self.dropout_keep_prob, self.train_mode)

    # Add residual connections
    return self_context + prev_layer.temporal_states
def single(
        queries: tf.Tensor,
        states: tf.Tensor,
        mask: tf.Tensor,
        n_heads: int,
        attention_dropout_callback: Callable[[tf.Tensor], tf.Tensor],
        dropout_callback: Callable[[tf.Tensor], tf.Tensor],
        normalize: bool = True,
        use_dropout: bool = True,
        residual: bool = True,
        use_att_transform_bias: bool = False):
    """Run attention on a single encoder.

    Arguments:
        queries: The input for the attention.
        states: The encoder states (keys & values).
        mask: The temporal mask of the encoder.
        n_heads: Number of attention heads to use.
        attention_dropout_callback: Dropout function to apply in attention.
        dropout_callback: Dropout function to apply on the attention output.
        normalize: If True, run layer normalization on the queries.
        use_dropout: If True, perform dropout on the attention output.
        residual: If True, sum the context vector with the input queries.
        use_att_transform_bias: If True, enable bias in the attention head
            projections (for all queries, keys and values).

    Returns:
        A Tensor that contains the context vector.
    """
    # Layer normalization
    normalized_queries = layer_norm(queries) if normalize else queries

    # Attend to the encoder
    # TODO handle attention histories
    encoder_context, _ = attention(
        queries=normalized_queries,
        keys=states,
        values=states,
        keys_mask=mask,
        num_heads=n_heads,
        dropout_callback=attention_dropout_callback,
        use_bias=use_att_transform_bias)

    # Apply dropout
    if use_dropout:
        encoder_context = dropout_callback(encoder_context)

    # Add residual connections
    if residual:
        encoder_context += queries

    return encoder_context
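A minimal usage sketch for `single`, assuming TensorFlow 1.x and the `dropout` helper used in the sublayers above; the tensor shapes, keep probabilities and the `train_mode` placeholder are illustrative only, not part of the original code.

import tensorflow as tf

train_mode = tf.placeholder(tf.bool, shape=[], name="train_mode")

decoder_queries = tf.zeros([8, 20, 512])   # [batch, time(q), channels]
encoder_states = tf.zeros([8, 35, 512])    # [batch, time(k), channels]
encoder_mask = tf.ones([8, 35])            # [batch, time(k)]

with tf.variable_scope("encdec_attention"):
    context = single(
        decoder_queries, encoder_states, encoder_mask, n_heads=8,
        attention_dropout_callback=lambda x: dropout(x, 0.9, train_mode),
        dropout_callback=lambda x: dropout(x, 0.8, train_mode))

# With the default flags, the queries are layer-normalized, attended,
# dropped out and summed with the residual, keeping shape [8, 20, 512].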
def rnn(self) -> Tuple[tf.Tensor, tf.Tensor]:
    layer_input = self.rnn_input  # type: tf.Tensor
    # pylint: disable=unsubscriptable-object
    layer_final = self.rnn_input[:, -1]
    # pylint: enable=unsubscriptable-object

    for i, rnn_spec in enumerate(self.rnn_specs):
        with tf.variable_scope("rnn_{}_{}".format(i, rnn_spec.direction),
                               reuse=tf.AUTO_REUSE):

            if self.add_layer_norm:
                layer_input = layer_norm(layer_input)

            layer_output, layer_final_output = rnn_layer(
                layer_input, self.input_sequence.lengths, rnn_spec)

            layer_output = dropout(layer_output, self.dropout_keep_prob,
                                   self.train_mode)
            layer_final_output = dropout(layer_final_output,
                                         self.dropout_keep_prob,
                                         self.train_mode)

            in_dim = layer_input.get_shape()[-1]
            out_dim = layer_output.get_shape()[-1]

            if self.add_residual and in_dim == out_dim:
                layer_input += layer_output
                layer_final += layer_final_output
            else:
                # pylint: disable=redefined-variable-type
                layer_input = layer_output
                layer_final = layer_final_output
                # pylint: enable=redefined-variable-type

    assert layer_final is not None

    if self.include_final_layer_norm:
        return layer_norm(layer_input), layer_norm(layer_final)

    return layer_input, layer_final
def parallel(queries: tf.Tensor,
             encoder_states: List[tf.Tensor],
             encoder_masks: List[tf.Tensor],
             heads: List[int],
             attention_dropout_callbacks: List[Callable[[tf.Tensor],
                                                        tf.Tensor]],
             dropout_callback: Callable[[tf.Tensor],
                                        tf.Tensor]) -> tf.Tensor:
    """Run attention with parallel input combination.

    The procedure is as follows:
    1. normalize queries,
    2. attend and dropout independently for every encoder,
    3. sum up the results
    4. add residual and return

    Arguments:
        queries: The input for the attention.
        encoder_states: The states of each encoder.
        encoder_masks: The temporal mask of each encoder.
        heads: Number of attention heads to use for each encoder.
        attention_dropout_callbacks: Dropout functions to apply
            in attention over each encoder.
        dropout_callback: The dropout function to apply on the outputs
            of each sub-attention.

    Returns:
        A Tensor that contains the context vector.
    """
    normalized_queries = layer_norm(queries)
    contexts = []

    for i, (states, mask, n_heads, attn_drop_cb) in enumerate(
            zip(encoder_states, encoder_masks, heads,
                attention_dropout_callbacks)):

        with tf.variable_scope("enc_{}".format(i)):
            contexts.append(
                single(normalized_queries, states, mask, n_heads,
                       attention_dropout_callback=attn_drop_cb,
                       dropout_callback=dropout_callback,
                       normalize=False, residual=False))

    return sum(contexts) + queries
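A usage sketch for `parallel` with two encoders, reusing the placeholder tensors from the `single` sketch above; the second encoder's shapes are again assumptions made for illustration.

states_img = tf.zeros([8, 49, 512])   # e.g. a second (image) encoder
mask_img = tf.ones([8, 49])

with tf.variable_scope("parallel_combination"):
    combined = parallel(
        decoder_queries,
        encoder_states=[encoder_states, states_img],
        encoder_masks=[encoder_mask, mask_img],
        heads=[8, 8],
        attention_dropout_callbacks=[
            lambda x: dropout(x, 0.9, train_mode),
            lambda x: dropout(x, 0.9, train_mode)],
        dropout_callback=lambda x: dropout(x, 0.8, train_mode))

# Each encoder is attended in its own "enc_{i}" scope; the per-encoder
# contexts are summed and the residual queries are added exactly once.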
def feedforward_sublayer(self, layer_input: tf.Tensor) -> tf.Tensor:
    """Create the feed-forward network sublayer."""
    # Layer normalization
    normalized_input = layer_norm(layer_input)

    # Feed-forward network hidden layer + ReLU
    ff_hidden = tf.layers.dense(
        normalized_input, self.ff_hidden_size, activation=tf.nn.relu,
        name="hidden_state")

    # Apply dropout on the activations
    ff_hidden = dropout(ff_hidden, self.dropout_keep_prob, self.train_mode)

    # Feed-forward output projection
    ff_output = tf.layers.dense(ff_hidden, self.dimension, name="output")

    # Apply dropout on the output projection
    ff_output = dropout(ff_output, self.dropout_keep_prob, self.train_mode)

    # Add residual connections
    return ff_output + layer_input
def hierarchical(
        queries: tf.Tensor,
        encoder_states: List[tf.Tensor],
        encoder_masks: List[tf.Tensor],
        heads: List[int],
        heads_hier: int,
        attention_dropout_callbacks: List[Callable[[tf.Tensor], tf.Tensor]],
        dropout_callback: Callable[[tf.Tensor], tf.Tensor]) -> tf.Tensor:
    """Run attention with hierarchical input combination.

    The procedure is as follows:
    1. normalize queries
    2. attend to every encoder
    3. attend to the resulting context vectors (reuse normalized queries)
    4. apply dropout, add residual connection and return

    Arguments:
        queries: The input for the attention.
        encoder_states: The states of each encoder.
        encoder_masks: The temporal mask of each encoder.
        heads: Number of attention heads to use for each encoder.
        heads_hier: Number of attention heads to use in the second attention.
        attention_dropout_callbacks: Dropout functions to apply
            in attention over each encoder.
        dropout_callback: The dropout function to apply in the second
            attention and over the outputs of each sub-attention.

    Returns:
        A Tensor that contains the context vector.
    """
    normalized_queries = layer_norm(queries)
    contexts = []

    batch = tf.shape(queries)[0]
    time_q = tf.shape(queries)[1]
    dimension = tf.shape(queries)[2]

    for i, (states, mask, n_heads, attn_drop_cb) in enumerate(
            zip(encoder_states, encoder_masks, heads,
                attention_dropout_callbacks)):

        with tf.variable_scope("enc_{}".format(i)):
            contexts.append(
                single(normalized_queries, states, mask, n_heads,
                       attention_dropout_callback=attn_drop_cb,
                       dropout_callback=dropout_callback,
                       normalize=False, residual=False))

    # context is of shape [batch, time(q), channels(v)],
    # stack to [batch, time(q), n_encoders, channels(v)]
    # reshape to [batch x time(q), n_encoders, channels(v)]
    stacked_contexts = tf.reshape(
        tf.stack(contexts, axis=2),
        [batch * time_q, len(encoder_states), dimension])

    # hierarchical mask: ones of shape [batch x time(q), n_encoders]
    hier_mask = tf.ones([batch * time_q, len(encoder_states)])

    # reshape queries to [batch x time(q), 1, channels(v)]
    reshaped_queries = tf.reshape(normalized_queries,
                                  [batch * time_q, 1, dimension])

    # returned shape [batch x time(q), 1, channels(v)]
    with tf.variable_scope("enc_hier"):
        # NOTE as attention dropout keep probability, we use the
        # dropout_keep_prob value instead of attention_dropout_keep_prob.
        encoder_context_stacked_batch = single(
            reshaped_queries, stacked_contexts, hier_mask, heads_hier,
            attention_dropout_callback=dropout_callback,
            dropout_callback=lambda x: x,
            normalize=False, use_dropout=False, residual=False)

        # reshape back to [batch, time(q), channels(v)]
        encoder_context = tf.reshape(encoder_context_stacked_batch,
                                     [batch, time_q, dimension])

    encoder_context = dropout_callback(encoder_context)

    return encoder_context + queries
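The hierarchical variant is called the same way, with an additional head count for the second-level attention. The sketch below reuses the assumed tensors from the earlier examples and mainly documents the shape handling.

with tf.variable_scope("hierarchical_combination"):
    combined_hier = hierarchical(
        decoder_queries,
        encoder_states=[encoder_states, states_img],
        encoder_masks=[encoder_mask, mask_img],
        heads=[8, 8],
        heads_hier=4,
        attention_dropout_callbacks=[
            lambda x: dropout(x, 0.9, train_mode),
            lambda x: dropout(x, 0.9, train_mode)],
        dropout_callback=lambda x: dropout(x, 0.8, train_mode))

# Internally, the two [8, 20, 512] contexts are stacked and reshaped into a
# length-2 pseudo-sequence per query position ([8 * 20, 2, 512]); the second
# attention with `heads_hier` heads mixes them before the residual is added.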