Example #1
def f(x, side_input):
    """f(x) for reversible layer, self-attention and enc-dec attention."""
    decoder_self_attention_bias = side_input[0]
    encoder_decoder_attention_bias = side_input[1]
    encoder_output = side_input[2]

    old_hid_size = hparams.hidden_size
    hparams.hidden_size = old_hid_size // 2

    with tf.variable_scope("self_attention"):
        y = common_attention.multihead_attention(
            common_layers.layer_preprocess(x, hparams), None,
            decoder_self_attention_bias, hparams.attention_key_channels
            or hparams.hidden_size, hparams.attention_value_channels
            or hparams.hidden_size, hparams.hidden_size, hparams.num_heads,
            hparams.attention_dropout)
        y = common_layers.layer_postprocess(x, y, hparams)
        if encoder_output is not None:
            with tf.variable_scope("encdec_attention"):
                y = common_attention.multihead_attention(
                    common_layers.layer_preprocess(x, hparams),
                    encoder_output, encoder_decoder_attention_bias,
                    hparams.attention_key_channels or hparams.hidden_size,
                    hparams.attention_value_channels
                    or hparams.hidden_size, hparams.hidden_size,
                    hparams.num_heads, hparams.attention_dropout)
                y = common_layers.layer_postprocess(x, y, hparams)
    hparams.hidden_size = old_hid_size
    return y
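
An f like this is normally paired with a feed-forward closure g (see Example #3) and handed to a reversible-block helper. The sketch below assumes TF 1.x's tf.contrib.layers.rev_block; decoder_input, the bias tensors, and g are placeholders that would live in the enclosing decoder function, not part of the example above.

# Sketch only: wiring f and a matching g into a reversible decoder stack.
x1, x2 = tf.split(decoder_input, 2, axis=-1)
y1, y2 = tf.contrib.layers.rev_block(
    x1, x2, f, g,
    num_layers=hparams.num_hidden_layers,
    f_side_input=[decoder_self_attention_bias,
                  encoder_decoder_attention_bias,
                  encoder_output],
    is_training=True)  # or derive from the model's mode
decoder_output = tf.concat([y1, y2], axis=-1)
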
Example #2
def transformer_encoder(encoder_input,
                        encoder_self_attention_bias,
                        hparams,
                        name="encoder"):
    """A stack of transformer layers.

  Args:
    encoder_input: a Tensor
    encoder_self_attention_bias: bias Tensor for self-attention
       (see common_attention.attention_bias())
    hparams: hyperparameters for model
    name: a string

  Returns:
    y: a Tensors
  """
    x = encoder_input
    with tf.variable_scope(name):
        pad_remover = None
        if hparams.use_pad_remover:
            pad_remover = expert_utils.PadRemover(
                common_attention.attention_bias_to_padding(
                    encoder_self_attention_bias))
        for layer in xrange(hparams.num_encoder_layers
                            or hparams.num_hidden_layers):
            with tf.variable_scope("layer_%d" % layer):
                with tf.variable_scope("self_attention"):
                    y = common_attention.multihead_attention(
                        common_layers.layer_preprocess(x, hparams),
                        None,
                        encoder_self_attention_bias,
                        hparams.attention_key_channels or hparams.hidden_size,
                        hparams.attention_value_channels
                        or hparams.hidden_size,
                        hparams.hidden_size,
                        hparams.num_heads,
                        hparams.attention_dropout,
                        attention_type=hparams.self_attention_type,
                        max_relative_position=hparams.max_relative_position)
                    x = common_layers.layer_postprocess(x, y, hparams)
                with tf.variable_scope("ffn"):
                    y = transformer_ffn_layer(
                        common_layers.layer_preprocess(x, hparams), hparams,
                        pad_remover)
                    x = common_layers.layer_postprocess(x, y, hparams)
        # if normalization is done in layer_preprocess, then it should also be done
        # on the output, since the output can grow very large, being the sum of
        # a whole stack of unnormalized layer outputs.
        return common_layers.layer_preprocess(x, hparams)
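
A minimal usage sketch, assuming inputs is an already-embedded [batch, length, hidden_size] Tensor and that the base Transformer hyperparameters are acceptable; everything except the common_attention/transformer helpers is illustrative.

# Sketch only: derive the padding-based self-attention bias, add timing
# signals, and run the encoder stack.
hparams = transformer.transformer_base()
padding = common_attention.embedding_to_padding(inputs)
encoder_self_attention_bias = common_attention.attention_bias_ignore_padding(padding)
encoder_input = common_attention.add_timing_signal_1d(inputs)
encoder_output = transformer_encoder(
    encoder_input, encoder_self_attention_bias, hparams)
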
Example #3
def g(x):
    """g(x) for reversible layer, feed-forward layer."""
    old_hid_size = hparams.hidden_size
    hparams.hidden_size = old_hid_size // 2
    with tf.variable_scope("ffn"):
        y = transformer.transformer_ffn_layer(
            common_layers.layer_preprocess(x, hparams), hparams)
        y = common_layers.layer_postprocess(x, y, hparams)
    hparams.hidden_size = old_hid_size
    return y
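
Together with an attention closure f (Example #1), a g like this implements the additive coupling of a reversible residual layer, which is why hidden_size is temporarily halved: each closure only ever sees one half of the channel-split activations. A schematic of that coupling, with x and side_input assumed to be in scope:

# Sketch only: the additive coupling applied by a reversible layer.
x1, x2 = tf.split(x, 2, axis=-1)    # each half has hidden_size // 2 channels
y1 = x1 + f(x2, side_input)         # attention half (Example #1)
y2 = x2 + g(y1)                     # feed-forward half (this example)
y = tf.concat([y1, y2], axis=-1)
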
Example #4
def attend(x, source, hparams, name):
    with tf.variable_scope(name):
        x = tf.squeeze(x, axis=2)
        if len(source.get_shape()) > 3:
            source = tf.squeeze(source, axis=2)
        source = common_attention.add_timing_signal_1d(source)
        y = common_attention.multihead_attention(
            common_layers.layer_preprocess(x, hparams), source, None,
            hparams.attention_key_channels or hparams.hidden_size,
            hparams.attention_value_channels or hparams.hidden_size,
            hparams.hidden_size, hparams.num_heads, hparams.attention_dropout)
        res = common_layers.layer_postprocess(x, y, hparams)
        return tf.expand_dims(res, axis=2)
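
A brief usage sketch; x is assumed to be a [batch, length, 1, hidden_size] Tensor and encoder_output the memory to attend over, with illustrative names throughout.

# Sketch only: let x attend over an encoder's output under its own scope.
hparams = transformer.transformer_base()
x = attend(x, encoder_output, hparams, name="attend_to_encoder")
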
Example #5
def attention_lm_decoder(decoder_input,
                         decoder_self_attention_bias,
                         hparams,
                         name="decoder"):
    """A stack of attention_lm layers.

  Args:
    decoder_input: a Tensor
    decoder_self_attention_bias: bias Tensor for self-attention
      (see common_attention.attention_bias())
    hparams: hyperparameters for model
    name: a string

  Returns:
    y: a Tensors
  """
    x = decoder_input
    with tf.variable_scope(name):
        for layer in xrange(hparams.num_hidden_layers):
            with tf.variable_scope("layer_%d" % layer):
                with tf.variable_scope("self_attention"):
                    y = common_attention.multihead_attention(
                        common_layers.layer_preprocess(x, hparams), None,
                        decoder_self_attention_bias,
                        hparams.attention_key_channels or hparams.hidden_size,
                        hparams.attention_value_channels
                        or hparams.hidden_size, hparams.hidden_size,
                        hparams.num_heads, hparams.attention_dropout)
                    x = common_layers.layer_postprocess(x, y, hparams)
                with tf.variable_scope("ffn"):
                    y = common_layers.conv_hidden_relu(
                        common_layers.layer_preprocess(x, hparams),
                        hparams.filter_size,
                        hparams.hidden_size,
                        dropout=hparams.relu_dropout)
                    x = common_layers.layer_postprocess(x, y, hparams)
        return common_layers.layer_preprocess(x, hparams)
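
A minimal sketch of driving this stack as a language model: the causal bias comes from common_attention.attention_bias_lower_triangle, while targets and the choice of hparams are assumptions for illustration.

# Sketch only: causal self-attention bias plus timing signal, then the stack.
hparams = transformer.transformer_base()  # assumed: any hparams set with the fields used above
decoder_self_attention_bias = common_attention.attention_bias_lower_triangle(
    tf.shape(targets)[1])
decoder_input = common_attention.add_timing_signal_1d(targets)
decoder_output = attention_lm_decoder(
    decoder_input, decoder_self_attention_bias, hparams)
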
Example #6
def transformer_decoder(decoder_input,
                        encoder_output,
                        decoder_self_attention_bias,
                        encoder_decoder_attention_bias,
                        hparams,
                        cache=None,
                        name="decoder"):
    """A stack of transformer layers.

  Args:
    decoder_input: a Tensor
    encoder_output: a Tensor
    decoder_self_attention_bias: bias Tensor for self-attention
      (see common_attention.attention_bias())
    encoder_decoder_attention_bias: bias Tensor for encoder-decoder attention
      (see common_attention.attention_bias())
    hparams: hyperparameters for model
    cache: dict, containing tensors which are the results of previous
        attentions, used for fast decoding.
    name: a string

  Returns:
    y: a Tensors
  """
    x = decoder_input
    with tf.variable_scope(name):
        for layer in xrange(hparams.num_decoder_layers
                            or hparams.num_hidden_layers):
            layer_name = "layer_%d" % layer
            layer_cache = cache[layer_name] if cache is not None else None
            with tf.variable_scope(layer_name):
                with tf.variable_scope("self_attention"):
                    y = common_attention.multihead_attention(
                        common_layers.layer_preprocess(x, hparams),
                        None,
                        decoder_self_attention_bias,
                        hparams.attention_key_channels or hparams.hidden_size,
                        hparams.attention_value_channels
                        or hparams.hidden_size,
                        hparams.hidden_size,
                        hparams.num_heads,
                        hparams.attention_dropout,
                        attention_type=hparams.self_attention_type,
                        max_relative_position=hparams.max_relative_position,
                        cache=layer_cache)
                    x = common_layers.layer_postprocess(x, y, hparams)
                if encoder_output is not None:
                    with tf.variable_scope("encdec_attention"):
                        # TODO(llion): Add caching.
                        y = common_attention.multihead_attention(
                            common_layers.layer_preprocess(x, hparams),
                            encoder_output, encoder_decoder_attention_bias,
                            hparams.attention_key_channels
                            or hparams.hidden_size,
                            hparams.attention_value_channels
                            or hparams.hidden_size, hparams.hidden_size,
                            hparams.num_heads, hparams.attention_dropout)
                        x = common_layers.layer_postprocess(x, y, hparams)
                with tf.variable_scope("ffn"):
                    y = transformer_ffn_layer(
                        common_layers.layer_preprocess(x, hparams), hparams)
                    x = common_layers.layer_postprocess(x, y, hparams)
        # if normalization is done in layer_preprocess, then it should also be done
        # on the output, since the output can grow very large, being the sum of
        # a whole stack of unnormalized layer outputs.
        return common_layers.layer_preprocess(x, hparams)
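
A usage sketch, assuming decoder_input holds the shifted, embedded targets and that encoder_output and encoder_decoder_attention_bias come from the matching encoder step; only the causal bias is built here.

# Sketch only: a single pass over the full target sequence (no cache, as in training).
hparams = transformer.transformer_base()
decoder_self_attention_bias = common_attention.attention_bias_lower_triangle(
    tf.shape(decoder_input)[1])
decoder_output = transformer_decoder(
    decoder_input, encoder_output, decoder_self_attention_bias,
    encoder_decoder_attention_bias, hparams)
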
Example #7
def ffn(x, hparams, name):
    with tf.variable_scope(name):
        y = transformer.transformer_ffn_layer(
            common_layers.layer_preprocess(x, hparams), hparams)
        return common_layers.layer_postprocess(x, y, hparams)
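
For completeness, a short usage sketch; x is assumed to be a [batch, length, hidden_size] Tensor and the base hyperparameters are used only for illustration.

# Sketch only: apply the wrapped feed-forward block under its own variable scope.
hparams = transformer.transformer_base()
y = ffn(x, hparams, name="ffn")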