def f(x, side_input):
    """f(x) for reversible layer, self-attention and enc-dec attention.

    While the layer is being built, `hparams.hidden_size` is temporarily set
    to half of its value (integer division) and is restored before returning,
    including when graph construction raises.

    Args:
      x: input passed to layer_preprocess and used as the residual input of
        layer_postprocess.
      side_input: indexable container where
        [0] is the decoder self-attention bias,
        [1] is the encoder-decoder attention bias,
        [2] is the encoder output, or None to skip enc-dec attention.

    Returns:
      The result of the last layer_postprocess call.
    """
    decoder_self_attention_bias = side_input[0]
    encoder_decoder_attention_bias = side_input[1]
    encoder_output = side_input[2]

    old_hid_size = hparams.hidden_size
    hparams.hidden_size = old_hid_size // 2
    # hparams is shared with the rest of the model; always put hidden_size
    # back, even if building the attention ops fails part-way.
    try:
        with tf.variable_scope("self_attention"):
            y = common_attention.multihead_attention(
                common_layers.layer_preprocess(x, hparams),
                None,
                decoder_self_attention_bias,
                hparams.attention_key_channels or hparams.hidden_size,
                hparams.attention_value_channels or hparams.hidden_size,
                hparams.hidden_size,
                hparams.num_heads,
                hparams.attention_dropout)
            y = common_layers.layer_postprocess(x, y, hparams)
            # NOTE(review): the nesting of "encdec_attention" inside
            # "self_attention" is assumed -- confirm against existing
            # variable names before relying on it.
            if encoder_output is not None:
                # NOTE(review): this branch reads layer_preprocess(x) and
                # overwrites y, so the self-attention output computed above
                # does not reach the returned value -- confirm this is
                # intended.
                with tf.variable_scope("encdec_attention"):
                    y = common_attention.multihead_attention(
                        common_layers.layer_preprocess(x, hparams),
                        encoder_output,
                        encoder_decoder_attention_bias,
                        hparams.attention_key_channels or hparams.hidden_size,
                        hparams.attention_value_channels or hparams.hidden_size,
                        hparams.hidden_size,
                        hparams.num_heads,
                        hparams.attention_dropout)
                    y = common_layers.layer_postprocess(x, y, hparams)
    finally:
        hparams.hidden_size = old_hid_size
    return y
def transformer_encoder(encoder_input,
                        encoder_self_attention_bias,
                        hparams,
                        name="encoder"):
    """Build the encoder stack: self-attention then feed-forward, per layer.

    The number of layers is `hparams.num_encoder_layers`, falling back to
    `hparams.num_hidden_layers` when the former is falsy.

    Args:
      encoder_input: a Tensor fed to the first layer.
      encoder_self_attention_bias: bias Tensor for self-attention
        (see common_attention.attention_bias()).
      hparams: hyperparameters for the model.
      name: a string, name of the enclosing variable scope.

    Returns:
      The output of the last layer, passed once more through
      layer_preprocess.
    """
    hidden = encoder_input
    with tf.variable_scope(name):
        if hparams.use_pad_remover:
            padding = common_attention.attention_bias_to_padding(
                encoder_self_attention_bias)
            pad_remover = expert_utils.PadRemover(padding)
        else:
            pad_remover = None

        num_layers = hparams.num_encoder_layers or hparams.num_hidden_layers
        for layer_idx in xrange(num_layers):
            with tf.variable_scope("layer_%d" % layer_idx):
                with tf.variable_scope("self_attention"):
                    attn_in = common_layers.layer_preprocess(hidden, hparams)
                    key_depth = (hparams.attention_key_channels
                                 or hparams.hidden_size)
                    value_depth = (hparams.attention_value_channels
                                   or hparams.hidden_size)
                    attn_out = common_attention.multihead_attention(
                        attn_in,
                        None,
                        encoder_self_attention_bias,
                        key_depth,
                        value_depth,
                        hparams.hidden_size,
                        hparams.num_heads,
                        hparams.attention_dropout,
                        attention_type=hparams.self_attention_type,
                        max_relative_position=hparams.max_relative_position)
                    hidden = common_layers.layer_postprocess(
                        hidden, attn_out, hparams)
                with tf.variable_scope("ffn"):
                    ffn_in = common_layers.layer_preprocess(hidden, hparams)
                    ffn_out = transformer_ffn_layer(
                        ffn_in, hparams, pad_remover)
                    hidden = common_layers.layer_postprocess(
                        hidden, ffn_out, hparams)
        # If normalization is done in layer_preprocess, then it should also
        # be done on the output, since the output can grow very large, being
        # the sum of a whole stack of unnormalized layer outputs.
        return common_layers.layer_preprocess(hidden, hparams)
def g(x):
    """g(x) for reversible layer, feed-forward layer.

    While the layer is being built, `hparams.hidden_size` is temporarily set
    to half of its value (integer division) and is restored before returning,
    including when graph construction raises.

    Args:
      x: input passed to layer_preprocess and used as the residual input of
        layer_postprocess.

    Returns:
      The result of layer_postprocess applied to the feed-forward output.
    """
    old_hid_size = hparams.hidden_size
    hparams.hidden_size = old_hid_size // 2
    # hparams is shared with the rest of the model; always put hidden_size
    # back, even if building the feed-forward ops fails part-way.
    try:
        with tf.variable_scope("ffn"):
            y = transformer.transformer_ffn_layer(
                common_layers.layer_preprocess(x, hparams), hparams)
            y = common_layers.layer_postprocess(x, y, hparams)
    finally:
        hparams.hidden_size = old_hid_size
    return y
def attend(x, source, hparams, name):
    """Apply multi-head attention from `x` over `source` under scope `name`.

    Axis 2 of `x` is squeezed away before attention and re-inserted on the
    result. `source` is squeezed on axis 2 only when its static shape has
    more than 3 dimensions, and a 1d timing signal is added to it. No
    attention bias is used.

    Args:
      x: query Tensor; axis 2 is squeezed, so it must have size 1 there.
      source: Tensor attended over.
      hparams: hyperparameters for the model.
      name: a string, name of the variable scope.

    Returns:
      The layer_postprocess result with axis 2 expanded back.
    """
    with tf.variable_scope(name):
        query = tf.squeeze(x, axis=2)
        memory = source
        if len(memory.get_shape()) > 3:
            memory = tf.squeeze(memory, axis=2)
        memory = common_attention.add_timing_signal_1d(memory)

        attn_in = common_layers.layer_preprocess(query, hparams)
        key_depth = hparams.attention_key_channels or hparams.hidden_size
        value_depth = hparams.attention_value_channels or hparams.hidden_size
        attn_out = common_attention.multihead_attention(
            attn_in,
            memory,
            None,
            key_depth,
            value_depth,
            hparams.hidden_size,
            hparams.num_heads,
            hparams.attention_dropout)
        combined = common_layers.layer_postprocess(query, attn_out, hparams)
        return tf.expand_dims(combined, axis=2)
def attention_lm_decoder(decoder_input,
                         decoder_self_attention_bias,
                         hparams,
                         name="decoder"):
    """Build a stack of attention_lm layers.

    Each of the `hparams.num_hidden_layers` layers applies self-attention
    followed by a conv_hidden_relu feed-forward block, each wrapped in
    layer_preprocess / layer_postprocess.

    Args:
      decoder_input: a Tensor fed to the first layer.
      decoder_self_attention_bias: bias Tensor for self-attention
        (see common_attention.attention_bias()).
      hparams: hyperparameters for the model.
      name: a string, name of the enclosing variable scope.

    Returns:
      The output of the last layer, passed once more through
      layer_preprocess.
    """
    hidden = decoder_input
    with tf.variable_scope(name):
        for layer_idx in xrange(hparams.num_hidden_layers):
            with tf.variable_scope("layer_%d" % layer_idx):
                with tf.variable_scope("self_attention"):
                    attn_in = common_layers.layer_preprocess(hidden, hparams)
                    key_depth = (hparams.attention_key_channels
                                 or hparams.hidden_size)
                    value_depth = (hparams.attention_value_channels
                                   or hparams.hidden_size)
                    attn_out = common_attention.multihead_attention(
                        attn_in,
                        None,
                        decoder_self_attention_bias,
                        key_depth,
                        value_depth,
                        hparams.hidden_size,
                        hparams.num_heads,
                        hparams.attention_dropout)
                    hidden = common_layers.layer_postprocess(
                        hidden, attn_out, hparams)
                with tf.variable_scope("ffn"):
                    ffn_in = common_layers.layer_preprocess(hidden, hparams)
                    ffn_out = common_layers.conv_hidden_relu(
                        ffn_in,
                        hparams.filter_size,
                        hparams.hidden_size,
                        dropout=hparams.relu_dropout)
                    hidden = common_layers.layer_postprocess(
                        hidden, ffn_out, hparams)
        return common_layers.layer_preprocess(hidden, hparams)
def transformer_decoder(decoder_input,
                        encoder_output,
                        decoder_self_attention_bias,
                        encoder_decoder_attention_bias,
                        hparams,
                        cache=None,
                        name="decoder"):
    """Build the decoder stack of transformer layers.

    Each layer applies self-attention, then encoder-decoder attention (only
    when `encoder_output` is not None), then a feed-forward layer. The number
    of layers is `hparams.num_decoder_layers`, falling back to
    `hparams.num_hidden_layers` when the former is falsy.

    Args:
      decoder_input: a Tensor fed to the first layer.
      encoder_output: a Tensor, or None to skip encoder-decoder attention.
      decoder_self_attention_bias: bias Tensor for self-attention
        (see common_attention.attention_bias()).
      encoder_decoder_attention_bias: bias Tensor for encoder-decoder
        attention (see common_attention.attention_bias()).
      hparams: hyperparameters for the model.
      cache: dict indexed by layer name ("layer_<i>"), containing tensors
        which are the results of previous attentions, used for fast
        decoding. Only passed to self-attention.
      name: a string, name of the enclosing variable scope.

    Returns:
      The output of the last layer, passed once more through
      layer_preprocess.
    """
    hidden = decoder_input
    with tf.variable_scope(name):
        num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers
        for layer_idx in xrange(num_layers):
            layer_name = "layer_%d" % layer_idx
            if cache is not None:
                layer_cache = cache[layer_name]
            else:
                layer_cache = None

            with tf.variable_scope(layer_name):
                with tf.variable_scope("self_attention"):
                    self_in = common_layers.layer_preprocess(hidden, hparams)
                    key_depth = (hparams.attention_key_channels
                                 or hparams.hidden_size)
                    value_depth = (hparams.attention_value_channels
                                   or hparams.hidden_size)
                    self_out = common_attention.multihead_attention(
                        self_in,
                        None,
                        decoder_self_attention_bias,
                        key_depth,
                        value_depth,
                        hparams.hidden_size,
                        hparams.num_heads,
                        hparams.attention_dropout,
                        attention_type=hparams.self_attention_type,
                        max_relative_position=hparams.max_relative_position,
                        cache=layer_cache)
                    hidden = common_layers.layer_postprocess(
                        hidden, self_out, hparams)

                if encoder_output is not None:
                    with tf.variable_scope("encdec_attention"):
                        # TODO(llion): Add caching.
                        encdec_in = common_layers.layer_preprocess(
                            hidden, hparams)
                        key_depth = (hparams.attention_key_channels
                                     or hparams.hidden_size)
                        value_depth = (hparams.attention_value_channels
                                       or hparams.hidden_size)
                        encdec_out = common_attention.multihead_attention(
                            encdec_in,
                            encoder_output,
                            encoder_decoder_attention_bias,
                            key_depth,
                            value_depth,
                            hparams.hidden_size,
                            hparams.num_heads,
                            hparams.attention_dropout)
                        hidden = common_layers.layer_postprocess(
                            hidden, encdec_out, hparams)

                with tf.variable_scope("ffn"):
                    ffn_in = common_layers.layer_preprocess(hidden, hparams)
                    ffn_out = transformer_ffn_layer(ffn_in, hparams)
                    hidden = common_layers.layer_postprocess(
                        hidden, ffn_out, hparams)
        # If normalization is done in layer_preprocess, then it should also
        # be done on the output, since the output can grow very large, being
        # the sum of a whole stack of unnormalized layer outputs.
        return common_layers.layer_preprocess(hidden, hparams)
def ffn(x, hparams, name):
    """Apply the transformer feed-forward layer to `x` under scope `name`.

    Args:
      x: input passed to layer_preprocess and used as the residual input of
        layer_postprocess.
      hparams: hyperparameters for the model.
      name: a string, name of the variable scope.

    Returns:
      The result of layer_postprocess applied to the feed-forward output.
    """
    with tf.variable_scope(name):
        ffn_in = common_layers.layer_preprocess(x, hparams)
        ffn_out = transformer.transformer_ffn_layer(ffn_in, hparams)
        return common_layers.layer_postprocess(x, ffn_out, hparams)