# Shared dependencies for this section (import paths assumed from the
# surrounding tensor2tensor codebase).
import tensorflow as tf

from six.moves import xrange  # pylint: disable=redefined-builtin

from tensor2tensor.models import common_attention
from tensor2tensor.models import common_layers


def transformer_decoder(decoder_input,
                        encoder_output,
                        residual_fn,
                        decoder_self_attention_bias,
                        encoder_decoder_attention_bias,
                        hparams,
                        name="decoder"):
  """A stack of transformer layers.

  Args:
    decoder_input: a Tensor
    encoder_output: a Tensor
    residual_fn: a function from (layer_input, layer_output) -> combined_output
    decoder_self_attention_bias: bias Tensor for self-attention
      (see common_attention.attention_bias())
    encoder_decoder_attention_bias: bias Tensor for encoder-decoder attention
      (see common_attention.attention_bias())
    hparams: hyperparameters for model
    name: a string

  Returns:
    y: a Tensor
  """
  x = decoder_input
  # Summaries don't work in the multi-problem setting yet.
  summaries = "problems" not in hparams.values() or len(hparams.problems) == 1
  with tf.variable_scope(name):
    for layer in xrange(hparams.num_hidden_layers):
      with tf.variable_scope("layer_%d" % layer):
        x = residual_fn(
            x,
            common_attention.multihead_attention(
                x,
                None,
                decoder_self_attention_bias,
                hparams.attention_key_channels or hparams.hidden_size,
                hparams.attention_value_channels or hparams.hidden_size,
                hparams.hidden_size,
                hparams.num_heads,
                hparams.attention_dropout,
                summaries=summaries,
                name="decoder_self_attention"))
        x = residual_fn(
            x,
            common_attention.multihead_attention(
                x,
                encoder_output,
                encoder_decoder_attention_bias,
                hparams.attention_key_channels or hparams.hidden_size,
                hparams.attention_value_channels or hparams.hidden_size,
                hparams.hidden_size,
                hparams.num_heads,
                hparams.attention_dropout,
                summaries=summaries,
                name="encdec_attention"))
        x = residual_fn(x, transformer_ffn_layer(x, hparams))
  return x
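
# A minimal sketch (not part of the original file) of a residual_fn in the
# style the stacks here expect: apply residual dropout to the layer output,
# add it to the layer input, then layer-normalize. The name
# hparams.residual_dropout is an assumed hyperparameter.
def example_residual_fn(x, y, hparams):
  return common_layers.layer_norm(
      x + tf.nn.dropout(y, 1.0 - hparams.residual_dropout))

# Usage sketch: transformer_decoder expects a two-argument function, so bind
# hparams first.
#   residual_fn = lambda x, y: example_residual_fn(x, y, hparams)
#   y = transformer_decoder(decoder_input, encoder_output, residual_fn,
#                           decoder_self_attention_bias,
#                           encoder_decoder_attention_bias, hparams)
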
def attention(targets_shifted, inputs_encoded, norm_fn, hparams, train,
              bias=None):
  """Complete attention layer with preprocessing."""
  separabilities = [hparams.separability, hparams.separability]
  if hparams.separability < 0:
    separabilities = [hparams.separability - 1, hparams.separability]
  targets_timed = common_layers.subseparable_conv_block(
      common_layers.add_timing_signal(targets_shifted),
      hparams.hidden_size, [((1, 1), (5, 1)), ((4, 1), (5, 1))],
      normalizer_fn=norm_fn,
      padding="LEFT",
      separabilities=separabilities,
      name="targets_time")
  if hparams.attention_type == "transformer":
    targets_timed = tf.squeeze(targets_timed, 2)
    target_shape = tf.shape(targets_timed)
    targets_segment = tf.zeros([target_shape[0], target_shape[1]])
    target_attention_bias = common_attention.attention_bias(
        targets_segment, targets_segment, lower_triangular=True)
    inputs_attention_bias = tf.zeros([
        tf.shape(inputs_encoded)[0], hparams.num_heads,
        tf.shape(targets_segment)[1],
        tf.shape(inputs_encoded)[1]
    ])
    attention_dropout = hparams.attention_dropout * tf.to_float(train)
    qv = common_attention.multihead_attention(
        targets_timed,
        None,
        target_attention_bias,
        hparams.hidden_size,
        hparams.hidden_size,
        hparams.hidden_size,
        hparams.num_heads,
        attention_dropout,
        name="self_attention",
        summaries=False)
    qv = common_attention.multihead_attention(
        qv,
        inputs_encoded,
        inputs_attention_bias,
        hparams.hidden_size,
        hparams.hidden_size,
        hparams.hidden_size,
        hparams.num_heads,
        attention_dropout,
        name="encdec_attention",
        summaries=False)
    return tf.expand_dims(qv, 2)
  elif hparams.attention_type == "simple":
    targets_with_attention = common_layers.simple_attention(
        targets_timed, inputs_encoded, bias=bias, summaries=False)
    return norm_fn(targets_shifted + targets_with_attention, name="attn_norm")
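
# A hedged sketch of a norm_fn compatible with attention() above: it is passed
# as normalizer_fn to subseparable_conv_block and also called directly with a
# `name` keyword, so it must accept (x, name=...). This example wraps
# tf.contrib.layers.layer_norm and is an assumption, not the file's own helper.
def example_norm_fn(x, name=None):
  with tf.variable_scope(name, default_name="norm"):
    return tf.contrib.layers.layer_norm(x)
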
def alt_transformer_decoder(decoder_input,
                            encoder_output,
                            residual_fn,
                            mask,
                            encoder_decoder_attention_bias,
                            hparams,
                            name="decoder"):
  """Alternative decoder."""
  x = decoder_input
  # Summaries don't work in the multi-problem setting yet.
  summaries = "problems" not in hparams.values() or len(hparams.problems) == 1
  with tf.variable_scope(name):
    for layer in xrange(hparams.num_hidden_layers):
      with tf.variable_scope("layer_%d" % layer):
        x_ = common_attention.multihead_attention(
            x,
            encoder_output,
            encoder_decoder_attention_bias,
            hparams.attention_key_channels or hparams.hidden_size,
            hparams.attention_value_channels or hparams.hidden_size,
            hparams.hidden_size,
            hparams.num_heads,
            hparams.attention_dropout,
            summaries=summaries,
            name="encdec_attention")
        x_ = residual_fn(x_, composite_layer(x_, mask, hparams))
        x = residual_fn(x, x_)
  return x
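
# Hypothetical usage sketch for alt_transformer_decoder. composite_layer is
# defined elsewhere in this file; the mask built here is an assumption: a
# float indicator that is 1.0 for positions whose embedding is nonzero and
# 0.0 for padding, reduced over the channel axis.
def example_alt_decode(decoder_input, encoder_output, residual_fn,
                       encoder_decoder_attention_bias, hparams):
  mask = tf.to_float(
      tf.reduce_any(tf.not_equal(decoder_input, 0.0), axis=-1, keep_dims=True))
  return alt_transformer_decoder(decoder_input, encoder_output, residual_fn,
                                 mask, encoder_decoder_attention_bias, hparams)
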
def transformer_encoder(encoder_input,
                        residual_fn,
                        encoder_self_attention_bias,
                        hparams,
                        name="encoder"):
  """A stack of transformer layers.

  Args:
    encoder_input: a Tensor
    residual_fn: a function from (layer_input, layer_output) -> combined_output
    encoder_self_attention_bias: bias Tensor for self-attention
      (see common_attention.attention_bias())
    hparams: hyperparameters for model
    name: a string

  Returns:
    y: a Tensor
  """
  x = encoder_input
  with tf.variable_scope(name):
    for layer in xrange(hparams.num_hidden_layers):
      with tf.variable_scope("layer_%d" % layer):
        x = residual_fn(
            x,
            common_attention.multihead_attention(
                x,
                None,
                encoder_self_attention_bias,
                hparams.attention_key_channels or hparams.hidden_size,
                hparams.attention_value_channels or hparams.hidden_size,
                hparams.hidden_size,
                hparams.num_heads,
                hparams.attention_dropout,
                name="encoder_self_attention"))
        x = residual_fn(x, transformer_ffn_layer(x, hparams))
  return x
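
# A hedged sketch of how the self-attention bias consumed by
# transformer_encoder is typically prepared: padding positions receive a large
# negative bias so the attention softmax ignores them. The helper names
# embedding_to_padding and attention_bias_ignore_padding mirror
# common_attention, but treat the exact names and signatures as assumptions.
def example_prepare_encoder_bias(encoder_input):
  encoder_padding = common_attention.embedding_to_padding(encoder_input)
  return common_attention.attention_bias_ignore_padding(encoder_padding)
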