Example 1
def transformer_decoder(decoder_input,
                        encoder_output,
                        residual_fn,
                        decoder_self_attention_bias,
                        encoder_decoder_attention_bias,
                        hparams,
                        name="decoder"):
  """A stack of transformer layers.

  Args:
    decoder_input: a Tensor
    encoder_output: a Tensor
    residual_fn: a function from (layer_input, layer_output) -> combined_output
    decoder_self_attention_bias: bias Tensor for self-attention
      (see common_attention.attention_bias())
    encoder_decoder_attention_bias: bias Tensor for encoder-decoder attention
      (see common_attention.attention_bias())
    hparams: hyperparameters for model
    name: a string

  Returns:
    y: a Tensor
  """
  x = decoder_input
  # Summaries don't work in multi-problem setting yet.
  summaries = "problems" not in hparams.values() or len(hparams.problems) == 1
  with tf.variable_scope(name):
    for layer in xrange(hparams.num_hidden_layers):
      with tf.variable_scope("layer_%d" % layer):
        # Masked self-attention over the decoder input.
        x = residual_fn(
            x,
            common_attention.multihead_attention(
                x,
                None,
                decoder_self_attention_bias,
                hparams.attention_key_channels or hparams.hidden_size,
                hparams.attention_value_channels or hparams.hidden_size,
                hparams.hidden_size,
                hparams.num_heads,
                hparams.attention_dropout,
                summaries=summaries,
                name="decoder_self_attention"))
        # Attention over the encoder output.
        x = residual_fn(
            x,
            common_attention.multihead_attention(
                x,
                encoder_output,
                encoder_decoder_attention_bias,
                hparams.attention_key_channels or hparams.hidden_size,
                hparams.attention_value_channels or hparams.hidden_size,
                hparams.hidden_size,
                hparams.num_heads,
                hparams.attention_dropout,
                summaries=summaries,
                name="encdec_attention"))
        # Position-wise feed-forward sublayer.
        x = residual_fn(x, transformer_ffn_layer(x, hparams))
  return x
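
To make the call pattern concrete, here is a minimal usage sketch. It is an illustration, not library code: it assumes TensorFlow 1.x, the surrounding tensor2tensor module context (common_attention, transformer_ffn_layer), and made-up shapes and hyperparameter values; the bias construction follows the pattern used in the attention example further down this page.

import tensorflow as tf

hparams = tf.contrib.training.HParams(
    num_hidden_layers=2,
    hidden_size=64,
    num_heads=4,
    attention_key_channels=0,    # 0 falls back to hidden_size above
    attention_value_channels=0,  # 0 falls back to hidden_size above
    attention_dropout=0.1)

batch, target_len, input_len = 8, 20, 30
decoder_input = tf.zeros([batch, target_len, hparams.hidden_size])
encoder_output = tf.zeros([batch, input_len, hparams.hidden_size])

# A lower-triangular bias hides future target positions; an all-zero bias
# leaves every encoder position visible.
targets_segment = tf.zeros([batch, target_len])
decoder_self_attention_bias = common_attention.attention_bias(
    targets_segment, targets_segment, lower_triangular=True)
encoder_decoder_attention_bias = tf.zeros(
    [batch, hparams.num_heads, target_len, input_len])

# The simplest possible residual combination; see the residual_fn sketch
# after Example 5 for a more realistic variant.
residual_fn = lambda layer_input, layer_output: layer_input + layer_output

decoder_output = transformer_decoder(
    decoder_input, encoder_output, residual_fn,
    decoder_self_attention_bias, encoder_decoder_attention_bias, hparams)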
Example 3
def attention(targets_shifted,
              inputs_encoded,
              norm_fn,
              hparams,
              train,
              bias=None):
  """Complete attention layer with preprocessing."""
  separabilities = [hparams.separability, hparams.separability]
  if hparams.separability < 0:
    separabilities = [hparams.separability - 1, hparams.separability]
  # Add timing signals and run a convolutional block over the shifted
  # targets before attending.
  targets_timed = common_layers.subseparable_conv_block(
      common_layers.add_timing_signal(targets_shifted),
      hparams.hidden_size, [((1, 1), (5, 1)), ((4, 1), (5, 1))],
      normalizer_fn=norm_fn,
      padding="LEFT",
      separabilities=separabilities,
      name="targets_time")
  if hparams.attention_type == "transformer":
    targets_timed = tf.squeeze(targets_timed, 2)
    target_shape = tf.shape(targets_timed)
    targets_segment = tf.zeros([target_shape[0], target_shape[1]])
    # Lower-triangular bias: each target position only sees earlier positions.
    target_attention_bias = common_attention.attention_bias(
        targets_segment, targets_segment, lower_triangular=True)
    # All-zero bias: every target position may attend to every input position.
    inputs_attention_bias = tf.zeros([
        tf.shape(inputs_encoded)[0], hparams.num_heads,
        tf.shape(targets_segment)[1],
        tf.shape(inputs_encoded)[1]
    ])

    attention_dropout = hparams.attention_dropout * tf.to_float(train)
    qv = common_attention.multihead_attention(
        targets_timed,
        None,
        target_attention_bias,
        hparams.hidden_size,
        hparams.hidden_size,
        hparams.hidden_size,
        hparams.num_heads,
        attention_dropout,
        name="self_attention",
        summaries=False)
    qv = common_attention.multihead_attention(
        qv,
        inputs_encoded,
        inputs_attention_bias,
        hparams.hidden_size,
        hparams.hidden_size,
        hparams.hidden_size,
        hparams.num_heads,
        attention_dropout,
        name="encdec_attention",
        summaries=False)
    return tf.expand_dims(qv, 2)
  elif hparams.attention_type == "simple":
    targets_with_attention = common_layers.simple_attention(
        targets_timed, inputs_encoded, bias=bias, summaries=False)
    return norm_fn(targets_shifted + targets_with_attention, name="attn_norm")
Example 5
def alt_transformer_decoder(decoder_input,
                            encoder_output,
                            residual_fn,
                            mask,
                            encoder_decoder_attention_bias,
                            hparams,
                            name="decoder"):
  """Alternative decoder."""
  x = decoder_input

  # Summaries don't work in multi-problem setting yet.
  summaries = "problems" not in hparams.values() or len(hparams.problems) == 1
  with tf.variable_scope(name):
    for layer in xrange(hparams.num_hidden_layers):
      with tf.variable_scope("layer_%d" % layer):

        x_ = common_attention.multihead_attention(
            x,
            encoder_output,
            encoder_decoder_attention_bias,
            hparams.attention_key_channels or hparams.hidden_size,
            hparams.attention_value_channels or hparams.hidden_size,
            hparams.hidden_size,
            hparams.num_heads,
            hparams.attention_dropout,
            summaries=summaries,
            name="encdec_attention")

        x_ = residual_fn(x_, composite_layer(x_, mask, hparams))
        x = residual_fn(x, x_)

  return x
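
Every decoder on this page takes residual_fn as a parameter rather than hard-coding the residual connection; alt_transformer_decoder even applies it twice per layer (once around composite_layer, once around the whole attention sub-block). As a sketch of one plausible residual_fn, an assumption rather than the library's actual definition: residual add with dropout on the sublayer output, followed by layer normalization.

import tensorflow as tf

def residual_fn(layer_input, layer_output, dropout=0.1):
  # Drop part of the sublayer output, add the shortcut path, then normalize.
  y = layer_input + tf.nn.dropout(layer_output, keep_prob=1.0 - dropout)
  return tf.contrib.layers.layer_norm(y)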
Example 6
def transformer_encoder(encoder_input,
                        residual_fn,
                        encoder_self_attention_bias,
                        hparams,
                        name="encoder"):
  """A stack of transformer layers.

  Args:
    encoder_input: a Tensor
    residual_fn: a function from (layer_input, layer_output) -> combined_output
    encoder_self_attention_bias: bias Tensor for self-attention
      (see common_attention.attention_bias())
    hparams: hyperparameters for model
    name: a string

  Returns:
    y: a Tensor
  """
  x = encoder_input
  with tf.variable_scope(name):
    for layer in xrange(hparams.num_hidden_layers):
      with tf.variable_scope("layer_%d" % layer):
        # Self-attention over the encoder input; any masking comes from
        # the bias.
        x = residual_fn(
            x,
            common_attention.multihead_attention(
                x,
                None,
                encoder_self_attention_bias,
                hparams.attention_key_channels or hparams.hidden_size,
                hparams.attention_value_channels or hparams.hidden_size,
                hparams.hidden_size,
                hparams.num_heads,
                hparams.attention_dropout,
                name="encoder_self_attention"))
        # Position-wise feed-forward sublayer.
        x = residual_fn(x, transformer_ffn_layer(x, hparams))
  return x
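
Reusing the definitions from the sketch after Example 1 (hparams, residual_fn, decoder_input, and the two decoder-side biases, all of which are illustrative assumptions) in a fresh graph, the encoder and decoder chain together in the obvious way; with no padding to hide, an all-zero self-attention bias lets every encoder position attend to every other position.

batch, input_len = 8, 30
encoder_input = tf.zeros([batch, input_len, hparams.hidden_size])
encoder_self_attention_bias = tf.zeros(
    [batch, hparams.num_heads, input_len, input_len])

encoder_output = transformer_encoder(
    encoder_input, residual_fn, encoder_self_attention_bias, hparams)
decoder_output = transformer_decoder(
    decoder_input, encoder_output, residual_fn,
    decoder_self_attention_bias, encoder_decoder_attention_bias, hparams)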