Example #1
  def f(x, side_input):
    """f(x) for reversible layer, self-attention and enc-dec attention."""
    decoder_self_attention_bias = side_input[0]
    encoder_decoder_attention_bias = side_input[1]
    encoder_output = side_input[2]

    old_hid_size = hparams.hidden_size
    hparams.hidden_size = old_hid_size // 2

    with tf.variable_scope("self_attention"):
      y = common_attention.multihead_attention(
          common_layers.layer_preprocess(
              x, hparams), None, decoder_self_attention_bias,
          hparams.attention_key_channels or hparams.hidden_size,
          hparams.attention_value_channels or hparams.hidden_size,
          hparams.hidden_size, hparams.num_heads, hparams.attention_dropout)
      y = common_layers.layer_postprocess(x, y, hparams)
      if encoder_output is not None:
        with tf.variable_scope("encdec_attention"):
          y = common_attention.multihead_attention(
              common_layers.layer_preprocess(
                  x, hparams), encoder_output, encoder_decoder_attention_bias,
              hparams.attention_key_channels or hparams.hidden_size,
              hparams.attention_value_channels or hparams.hidden_size,
              hparams.hidden_size, hparams.num_heads, hparams.attention_dropout)
          y = common_layers.layer_postprocess(x, y, hparams)
    hparams.hidden_size = old_hid_size
    return y
def transformer_encoder_layers(inputs,
                               num_layers,
                               hparams,
                               attention_type=AttentionType.GLOBAL,
                               self_attention_bias=None,
                               q_padding="VALID",
                               kv_padding="VALID",
                               name="transformer"):
  """Multi layer transformer encoder."""
  x = inputs
  x = tf.nn.dropout(x, 1.0 - hparams.layer_prepostprocess_dropout)

  for layer in range(num_layers):
    # attention layers + skip connections
    with tf.variable_scope("%s_layer_%d" % (name, layer)):
      if attention_type == AttentionType.LOCAL_2D:
        y = local_attention_2d(common_layers.layer_preprocess(x, hparams),
                               hparams,
                               attention_type="local_attention_2d")
      elif attention_type == AttentionType.LOCAL_1D:
        y = local_attention_1d(common_layers.layer_preprocess(x, hparams),
                               hparams,
                               attention_type="local_unmasked",
                               q_padding=q_padding, kv_padding=kv_padding)
      elif attention_type == AttentionType.GLOBAL:
        y = full_self_attention(common_layers.layer_preprocess(x, hparams),
                                self_attention_bias, hparams,
                                q_padding=q_padding, kv_padding=kv_padding)
      x = common_layers.layer_postprocess(x, y, hparams)
      # feed-fwd layer + skip connections
      y = ffn_layer(common_layers.layer_preprocess(x, hparams), hparams)
      x = common_layers.layer_postprocess(x, y, hparams)
  return common_layers.layer_preprocess(x, hparams)
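As a hedged usage sketch (the tensor layout and hparams fields here are assumptions, not part of the listed source), the encoder stack above could be invoked like this:

# inputs is assumed to be [batch, length, hparams.hidden_size] for the 1-D attention types.
encoder_output = transformer_encoder_layers(
    inputs,
    num_layers=hparams.num_encoder_layers,
    hparams=hparams,
    attention_type=AttentionType.LOCAL_1D,
    self_attention_bias=None,
    q_padding="VALID",
    kv_padding="VALID")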
Example #3
def image_encoder(image_feat,
                  hparams,
                  name="image_encoder",
                  save_weights_to=None,
                  make_image_summary=True):
  """A stack of self attention layers."""

  x = image_feat
  image_hidden_size = hparams.image_hidden_size or hparams.hidden_size
  image_filter_size = hparams.image_filter_size or hparams.filter_size
  with tf.variable_scope(name):
    for layer in range(hparams.num_encoder_layers or hparams.num_hidden_layers):
      with tf.variable_scope("layer_%d" % layer):
        with tf.variable_scope("self_attention"):
          y = vqa_layers.multihead_attention(
              common_layers.layer_preprocess(x, hparams),
              None,
              None,
              hparams.attention_key_channels or image_hidden_size,
              hparams.attention_value_channels or image_hidden_size,
              image_hidden_size,
              hparams.num_heads,
              hparams.attention_dropout,
              attention_type=hparams.image_self_attention_type,
              save_weights_to=save_weights_to,
              make_image_summary=make_image_summary,
              scale_dotproduct=hparams.scale_dotproduct,
          )
          utils.collect_named_outputs(
              "norms", "image_feat_self_attention_%d"%(layer),
              tf.norm(y, axis=-1))
          x = common_layers.layer_postprocess(x, y, hparams)
          utils.collect_named_outputs(
              "norms", "image_feat_self_attention_postprocess_%d"%(layer),
              tf.norm(x, axis=-1))
        with tf.variable_scope("ffn"):
          y = common_layers.dense_relu_dense(
              common_layers.layer_preprocess(x, hparams),
              image_filter_size,
              image_hidden_size,
              dropout=hparams.relu_dropout,
          )
          utils.collect_named_outputs(
              "norms", "image_feat_ffn_%d"%(layer), tf.norm(y, axis=-1))
          x = common_layers.layer_postprocess(x, y, hparams)
          utils.collect_named_outputs(
              "norms", "image_feat_ffn_postprocess_%d"%(layer),
              tf.norm(x, axis=-1))
    # if normalization is done in layer_preprocess, then it should also be done
    # on the output, since the output can grow very large, being the sum of
    # a whole stack of unnormalized layer outputs.
    return common_layers.layer_preprocess(x, hparams)
Example #4
def residual_block_layer(inputs, hparams):
  """Residual block over inputs.

  Runs a residual block consisting of
    conv: kernel_size x kernel_size
    conv: 1x1
    dropout, add and normalize according to hparams.layer_postprocess_sequence.

  Args:
    inputs: Tensor of shape [batch, height, width, hparams.hidden_size].
    hparams: tf.contrib.training.HParams.

  Returns:
    Tensor of shape [batch, height, width, hparams.hidden_size].
  """
  kernel = (hparams.res_kernel_size, hparams.res_kernel_size)
  x = inputs
  for i in range(hparams.num_res_layers):
    with tf.variable_scope("res_conv_%d" % i):
      # kernel_size x kernel_size conv block
      y = common_layers.conv_block(
          common_layers.layer_norm(x, hparams.hidden_size, name="lnorm"),
          hparams.hidden_size, [((1, 1), kernel)],
          strides=(1, 1),
          padding="SAME",
          name="residual_conv")
      # 1x1 conv block
      y = common_layers.conv_block(
          y,
          hparams.hidden_size, [((1, 1), (1, 1))],
          strides=(1, 1),
          padding="SAME",
          name="residual_dense")
      x = common_layers.layer_postprocess(x, y, hparams)
  return x
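The docstring above lists the input shape and the hparams fields the block reads. A minimal call sketch with hypothetical hparams values (an assumption for illustration; layer_postprocess additionally reads fields such as layer_postprocess_sequence, layer_prepostprocess_dropout, norm_type and norm_epsilon):

# Hypothetical hparams for illustration only.
hparams = tf.contrib.training.HParams(
    hidden_size=256,
    res_kernel_size=3,
    num_res_layers=2,
    layer_postprocess_sequence="dan",  # dropout, add, normalize
    layer_prepostprocess_dropout=0.1,
    norm_type="layer",
    norm_epsilon=1e-6)
inputs = tf.zeros([8, 32, 32, hparams.hidden_size])
outputs = residual_block_layer(inputs, hparams)  # shape preserved: [8, 32, 32, 256]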
def transformer_layers_sharded(dp,
                               ps_devices,
                               inputs,
                               num_layers,
                               hparams,
                               self_attention_bias=None,
                               enc_output=None,
                               attention_type=AttentionType.GLOBAL,
                               name="transformer"):
  """Multi layer transformer, sharded by the data parallelism dp."""
  x = inputs
  extra_loss = tf.constant(0.0)
  moe_hidden_sizes = [int(s) for s in hparams.moe_hidden_sizes.split(",")]
  expert_fn = expert_utils.ffn_expert_fn(
      hparams.hidden_size, moe_hidden_sizes, hparams.hidden_size)
  x = dp(tf.nn.dropout, x, 1.0 - hparams.layer_prepostprocess_dropout)
  for layer in range(num_layers):
    with tf.variable_scope("%s_layer_%d" % (name, layer)):
      # self-attention
      if attention_type == AttentionType.LOCAL_2D:
        y = dp(local_attention_2d(common_layers.layer_preprocess(x, hparams),
                                  hparams,
                                  attention_type="masked_local_attention_2d"))
      elif attention_type == AttentionType.LOCAL_1D:
        y = dp(local_attention_1d(common_layers.layer_preprocess(x, hparams),
                                  hparams,
                                  attention_type="local_mask_right",
                                  q_padding="LEFT", kv_padding="LEFT"))
      elif attention_type == AttentionType.GLOCAL:
        y = dp(local_global_attention(
            common_layers.layer_preprocess(x, hparams), self_attention_bias,
            hparams, q_padding="LEFT", kv_padding="LEFT"))
      elif attention_type == AttentionType.GLOBAL:
        self_attention_bias = dp(get_self_attention_bias(x))
        y = dp(full_self_attention(common_layers.layer_preprocess(x, hparams),
                                   self_attention_bias, hparams,
                                   q_padding="LEFT", kv_padding="LEFT"))
      x = common_layers.layer_postprocess(x, y, hparams)
      if enc_output is not None:
        y = dp(encdec_attention_1d(common_layers.layer_preprocess(x, hparams),
                                   enc_output, None, hparams))
        x = dp(common_layers.layer_postprocess, x, y, hparams)
      with tf.variable_scope("ffn"):
        if str(layer) in hparams.moe_layers_decoder.split(","):
          y, loss = expert_utils.distributed_moe(
              dp,
              ps_devices,
              common_layers.layer_preprocess(x, hparams),
              hparams.mode == tf.estimator.ModeKeys.TRAIN,
              input_size=hparams.hidden_size,
              expert_fn=expert_fn,
              num_experts=hparams.moe_num_experts,
              k=hparams.moe_k,
              loss_coef=hparams.moe_loss_coef)
          extra_loss += loss
          x = dp(common_layers.layer_postprocess, x, y, hparams)
        else:
          y = dp(ffn_layer, common_layers.layer_preprocess(x, hparams), hparams)
          x = dp(common_layers.layer_postprocess, x, y, hparams)
  return dp(common_layers.layer_preprocess, x, hparams), extra_loss
def transformer_decoder_layers(inputs,
                               encoder_output,
                               bias,
                               num_layers,
                               hparams,
                               attention_type=AttentionType.LOCAL_2D,
                               name="transformer"):
  """Multi layer transformer."""
  x = inputs
  x = tf.nn.dropout(x, 1.0 - hparams.layer_prepostprocess_dropout)
  if attention_type == AttentionType.DILATED:
    assert len(hparams.gap_sizes) == num_layers
  for layer in xrange(num_layers):
    with tf.variable_scope("%s_layer_%d" % (name, layer)):
      # self-attention + skip connections
      if attention_type == AttentionType.LOCAL_2D:
        y = local_attention_2d(common_layers.layer_preprocess(x, hparams),
                               hparams,
                               attention_type="masked_local_attention_2d")
      elif attention_type == AttentionType.LOCAL_1D:
        y = local_attention_1d(common_layers.layer_preprocess(x, hparams),
                               bias, hparams,
                               attention_type="local_mask_right",
                               q_padding="LEFT", kv_padding="LEFT")
      elif attention_type == AttentionType.GLOCAL:
        y = local_global_attention(common_layers.layer_preprocess(x, hparams),
                                   bias, hparams,
                                   q_padding="LEFT", kv_padding="LEFT")
      elif attention_type == AttentionType.DILATED:
        y = dilated_attention_1d(common_layers.layer_preprocess(x, hparams),
                                 bias, hparams, q_padding="LEFT",
                                 kv_padding="LEFT",
                                 gap_size=hparams.gap_sizes[layer])
      elif attention_type == AttentionType.GLOBAL:
        y = full_self_attention(common_layers.layer_preprocess(x, hparams),
                                bias, hparams,
                                q_padding="LEFT", kv_padding="LEFT")
      x = common_layers.layer_postprocess(x, y, hparams)
      # enc-dec attention + skip connections
      if encoder_output is not None:
        y = encdec_attention_1d(common_layers.layer_preprocess(x, hparams),
                                encoder_output, hparams)
        x = common_layers.layer_postprocess(x, y, hparams)
      # feed-fwd layers + skip connections
      y = ffn_layer(common_layers.layer_preprocess(x, hparams), hparams)
      x = common_layers.layer_postprocess(x, y, hparams)
  return common_layers.layer_preprocess(x, hparams)
Example #7
  def g(x):
    """g(x) for reversible layer, feed-forward layer."""
    old_hid_size = hparams.hidden_size
    hparams.hidden_size = old_hid_size // 2
    with tf.variable_scope("ffn"):
      y = transformer.transformer_ffn_layer(
          common_layers.layer_preprocess(x, hparams), hparams)
      y = common_layers.layer_postprocess(x, y, hparams)
    hparams.hidden_size = old_hid_size
    return y
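f from Example #1 and g above are the two halves of a reversible residual layer. A minimal sketch of how they can be combined, assuming the TF 1.x tf.contrib.layers.rev_block API and that decoder_input, the bias tensors and encoder_output are defined as in Example #1:

# Split the channels in half, run the reversible stack, then re-concatenate.
x1, x2 = tf.split(decoder_input, 2, axis=-1)
y1, y2 = tf.contrib.layers.rev_block(
    x1, x2, f, g,
    num_layers=hparams.num_hidden_layers,
    f_side_input=[decoder_self_attention_bias,
                  encoder_decoder_attention_bias,
                  encoder_output],
    is_training=hparams.mode == tf.estimator.ModeKeys.TRAIN)
decoder_output = tf.concat([y1, y2], axis=-1)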
Example #8
def attend(x, source, hparams, name):
  with tf.variable_scope(name):
    x = tf.squeeze(x, axis=2)
    if len(source.get_shape()) > 3:
      source = tf.squeeze(source, axis=2)
    source = common_attention.add_timing_signal_1d(source)
    y = common_attention.multihead_attention(
        common_layers.layer_preprocess(x, hparams), source, None,
        hparams.attention_key_channels or hparams.hidden_size,
        hparams.attention_value_channels or hparams.hidden_size,
        hparams.hidden_size, hparams.num_heads,
        hparams.attention_dropout)
    res = common_layers.layer_postprocess(x, y, hparams)
    return tf.expand_dims(res, axis=2)
Example #9
def compress_self_attention_layer(x, hparams, name=None):
  """Attend function."""
  with tf.variable_scope(name, default_name="compress_self_attention"):
    x, xshape, _ = cia.maybe_reshape_4d_to_3d(x)
    y = common_attention.multihead_attention(
        common_layers.layer_preprocess(x, hparams),
        None,
        None,
        hparams.attention_key_channels or hparams.hidden_size,
        hparams.attention_value_channels or hparams.hidden_size,
        hparams.hidden_size, hparams.num_heads,
        hparams.attention_dropout)
    res = common_layers.layer_postprocess(x, y, hparams)
    return tf.reshape(res, xshape)
Example #10
def attention_lm_decoder(decoder_input,
                         decoder_self_attention_bias,
                         hparams,
                         name="decoder"):
  """A stack of attention_lm layers.

  Args:
    decoder_input: a Tensor
    decoder_self_attention_bias: bias Tensor for self-attention
      (see common_attention.attention_bias())
    hparams: hyperparameters for model
    name: a string

  Returns:
    y: a Tensor
  """
  x = decoder_input
  with tf.variable_scope(name):
    for layer in xrange(hparams.num_hidden_layers):
      with tf.variable_scope("layer_%d" % layer):
        with tf.variable_scope("self_attention"):
          y = common_attention.multihead_attention(
              common_layers.layer_preprocess(
                  x, hparams), None, decoder_self_attention_bias,
              hparams.attention_key_channels or hparams.hidden_size,
              hparams.attention_value_channels or hparams.hidden_size,
              hparams.hidden_size, hparams.num_heads, hparams.attention_dropout)
          x = common_layers.layer_postprocess(x, y, hparams)
        with tf.variable_scope("ffn"):
          y = common_layers.conv_hidden_relu(
              common_layers.layer_preprocess(x, hparams),
              hparams.filter_size,
              hparams.hidden_size,
              dropout=hparams.relu_dropout)
          x = common_layers.layer_postprocess(x, y, hparams)
    return common_layers.layer_preprocess(x, hparams)
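The decoder_self_attention_bias argument referenced in the docstring is typically the causal (lower-triangular) bias. A hedged usage sketch, assuming tensor2tensor's common_attention and common_layers helpers:

# Causal bias so position i only attends to positions <= i.
length = common_layers.shape_list(decoder_input)[1]
decoder_self_attention_bias = common_attention.attention_bias_lower_triangle(length)
decoder_output = attention_lm_decoder(
    decoder_input, decoder_self_attention_bias, hparams)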
Example #11
def attend(x, source, hparams, name):
  """Attend function."""
  with tf.variable_scope(name):
    # x = tf.squeeze(x, axis=2)
    x, xshape, _ = cia.maybe_reshape_4d_to_3d(x)
    if len(source.get_shape()) > 3:
      source = tf.squeeze(source, axis=2)
    source = common_attention.add_timing_signal_1d(source)
    y = common_attention.multihead_attention(
        common_layers.layer_preprocess(x, hparams),
        source,
        None,
        hparams.attention_key_channels or hparams.hidden_size,
        hparams.attention_value_channels or hparams.hidden_size,
        hparams.hidden_size, hparams.num_heads,
        hparams.attention_dropout)
    res = common_layers.layer_postprocess(x, y, hparams)
    return tf.reshape(res, xshape)
Example #12
def transformer_encoder_gate(encoder_input,
                             encoder_self_attention_bias,
                             hparams,
                             name="encoder"):
    """A stack of transformer layers.

    Args:
      encoder_input: a Tensor
      encoder_self_attention_bias: bias Tensor for self-attention
         (see common_attention.attention_bias())
      hparams: hyperparameters for model
      name: a string

    Returns:
      y: a Tensor
    """
    x = encoder_input
    with tf.variable_scope(name):
        pad_remover = None
        if hparams.use_pad_remover:
            pad_remover = expert_utils.PadRemover(
                common_attention.attention_bias_to_padding(
                    encoder_self_attention_bias))
        for layer in xrange(hparams.num_encoder_layers or
                                    hparams.num_hidden_layers):
            with tf.variable_scope("layer_%d" % layer):
                with tf.variable_scope("self_attention"):
                    y = common_attention.multihead_attention(
                        common_layers.layer_preprocess(x, hparams),
                        None,
                        encoder_self_attention_bias,
                        hparams.attention_key_channels or hparams.hidden_size,
                        hparams.attention_value_channels or hparams.hidden_size,
                        hparams.hidden_size,
                        hparams.num_heads,
                        hparams.attention_dropout,
                        attention_type=hparams.self_attention_type,
                        max_relative_position=hparams.max_relative_position)
                    x = common_layers.layer_postprocess(x, y, hparams)

                    gate_fiter = tf.get_variable(
                        'gate_layer_%d' % layer,
                        [1, hparams.hidden_size, hparams.hidden_size],
                        tf.float32, initializer=tf.contrib.layers.xavier_initializer())
                    gate_x = tf.tanh(
                        tf.nn.conv1d(x, gate_fiter, 1, 'SAME'))
                    x *= gate_x

                with tf.variable_scope("ffn"):
                    y = transformer_ffn_layer(
                        common_layers.layer_preprocess(x, hparams), hparams, pad_remover)
                    x = common_layers.layer_postprocess(x, y, hparams)

                    gate_fiter = tf.get_variable(
                        'gate_layer_%d' % layer,
                        [1, hparams.hidden_size, hparams.hidden_size],
                        tf.float32, initializer=tf.contrib.layers.xavier_initializer())
                    gate_x = tf.tanh(
                        tf.nn.conv1d(x, gate_fiter, 1, 'SAME'))
                    x *= gate_x
        # if normalization is done in layer_preprocess, then it should also be done
        # on the output, since the output can grow very large, being the sum of
        # a whole stack of unnormalized layer outputs.
        return common_layers.layer_preprocess(x, hparams)
def transformer_decoder_fast(decoder_input,
                             encoder_output,
                             decoder_self_attention_bias,
                             encoder_decoder_attention_bias,
                             hparams,
                             cache=None,
                             name="decoder"):
    """A stack of transformer layers.
  Args:
    decoder_input: a Tensor
    encoder_output: a Tensor
    decoder_self_attention_bias: bias Tensor for self-attention
      (see common_attention.attention_bias())
    encoder_decoder_attention_bias: bias Tensor for encoder-decoder attention
      (see common_attention.attention_bias())
    hparams: hyperparameters for model
    cache: dict, containing tensors which are the results of previous
        attentions, used for fast decoding.
    name: a string
  Returns:
    y: a Tensor
  """
    x = decoder_input
    with tf.variable_scope(name):
        for layer in range(hparams.num_decoder_layers
                           or hparams.num_hidden_layers):
            layer_name = "layer_%d" % layer
            layer_cache = cache[layer_name] if cache is not None else None
            with tf.variable_scope(layer_name):
                with tf.variable_scope("self_attention"):
                    y = common_attention.multihead_attention(
                        common_layers.layer_preprocess(x, hparams),
                        None,
                        decoder_self_attention_bias,
                        hparams.attention_key_channels or hparams.hidden_size,
                        hparams.attention_value_channels
                        or hparams.hidden_size,
                        hparams.hidden_size,
                        hparams.num_heads,
                        hparams.attention_dropout,
                        attention_type=hparams.self_attention_type,
                        max_relative_position=hparams.max_relative_position,
                        cache=layer_cache)
                    x = common_layers.layer_postprocess(x, y, hparams)
                if encoder_output is not None:
                    with tf.variable_scope("encdec_attention"):
                        y = multihead_attention(
                            common_layers.layer_preprocess(x, hparams),
                            encoder_output,
                            encoder_decoder_attention_bias,
                            hparams.attention_key_channels
                            or hparams.hidden_size,
                            hparams.attention_value_channels
                            or hparams.hidden_size,
                            hparams.hidden_size,
                            hparams.num_heads,
                            hparams.attention_dropout,
                            cache=layer_cache)
                        x = common_layers.layer_postprocess(x, y, hparams)
                with tf.variable_scope("ffn"):
                    y = transformer.transformer_ffn_layer(
                        common_layers.layer_preprocess(x, hparams), hparams)
                    x = common_layers.layer_postprocess(x, y, hparams)
        # if normalization is done in layer_preprocess, then it should also be done
        # on the output, since the output can grow very large, being the sum of
        # a whole stack of unnormalized layer outputs.
        return common_layers.layer_preprocess(x, hparams)
Example #14
def transformer_decoder(decoder_input,
                        encoder_output,
                        decoder_self_attention_bias,
                        encoder_decoder_attention_bias,
                        hparams,
                        cache=None,
                        name="decoder",
                        ctxly=None,
                        model_config=None):
    """A stack of transformer layers.

  Args:
    decoder_input: a Tensor
    encoder_output: a Tensor
    decoder_self_attention_bias: bias Tensor for self-attention
      (see common_attention.attention_bias())
    encoder_decoder_attention_bias: bias Tensor for encoder-decoder attention
      (see common_attention.attention_bias())
    hparams: hyperparameters for model
    cache: dict, containing tensors which are the results of previous
        attentions, used for fast decoding.
    name: a string

  Returns:
    y: a Tensor
  """
    x = decoder_input
    contexts = None
    if ctxly is None:
        ctxly = (hparams.num_decoder_layers or hparams.num_hidden_layers) - 1
    print('Use Context layer %s for output' % ctxly)
    with tf.variable_scope(name):
        for layer in xrange(hparams.num_decoder_layers
                            or hparams.num_hidden_layers):
            layer_name = "layer_%d" % layer
            layer_cache = cache[layer_name] if cache is not None else None
            with tf.variable_scope(layer_name):
                with tf.variable_scope("self_attention"):
                    y = common_attention.multihead_attention(
                        common_layers.layer_preprocess(x, hparams),
                        None,
                        decoder_self_attention_bias,
                        hparams.attention_key_channels or hparams.hidden_size,
                        hparams.attention_value_channels
                        or hparams.hidden_size,
                        hparams.hidden_size,
                        hparams.num_heads,
                        hparams.attention_dropout,
                        attention_type=hparams.self_attention_type,
                        max_relative_position=hparams.max_relative_position,
                        cache=layer_cache)
                    x = common_layers.layer_postprocess(x, y, hparams)
                if encoder_output is not None:
                    with tf.variable_scope("encdec_attention"):
                        # TODO(llion): Add caching.
                        y = common_attention.multihead_attention(
                            common_layers.layer_preprocess(x, hparams),
                            encoder_output,
                            encoder_decoder_attention_bias,
                            hparams.attention_key_channels
                            or hparams.hidden_size,
                            hparams.attention_value_channels
                            or hparams.hidden_size,
                            hparams.hidden_size,
                            hparams.num_heads,
                            hparams.attention_dropout,
                            model_config=model_config,
                            decoder_input=decoder_input)
                        if layer == ctxly:
                            contexts = tf.identity(y)
                        x = common_layers.layer_postprocess(x, y, hparams)
                with tf.variable_scope("ffn"):
                    y = transformer_ffn_layer(common_layers.layer_preprocess(
                        x, hparams),
                                              hparams,
                                              conv_padding="LEFT")
                    x = common_layers.layer_postprocess(x, y, hparams)
        # if normalization is done in layer_preprocess, then it should also be done
        # on the output, since the output can grow very large, being the sum of
        # a whole stack of unnormalized layer outputs.
        return common_layers.layer_preprocess(x, hparams), contexts
Example #15
def image_question_encoder(encoder_inputs,
                           encoder_self_attention_bias,
                           hparams,
                           query=None,
                           name="image_question_encoder",
                           save_weights_to=None,
                           make_image_summary=True):
    """A stack of self attention layers."""
    x = encoder_inputs
    with tf.variable_scope(name):
        for layer in range(hparams.num_encoder_layers
                           or hparams.num_hidden_layers):
            with tf.variable_scope("layer_%d" % layer):
                with tf.variable_scope("self_attention"):
                    y = vqa_layers.multihead_attention(
                        common_layers.layer_preprocess(x, hparams),
                        None,
                        encoder_self_attention_bias,
                        hparams.attention_key_channels or hparams.hidden_size,
                        hparams.attention_value_channels
                        or hparams.hidden_size,
                        hparams.hidden_size,
                        hparams.num_heads,
                        hparams.attention_dropout,
                        attention_type=hparams.self_attention_type,
                        block_length=hparams.block_length,
                        save_weights_to=save_weights_to,
                        make_image_summary=make_image_summary,
                        scale_dotproduct=hparams.scale_dotproduct,
                    )
                    utils.collect_named_outputs(
                        "norms", "encoder_self_attention_%d" % (layer),
                        tf.norm(y, axis=-1))
                    x = common_layers.layer_postprocess(x, y, hparams)
                    utils.collect_named_outputs(
                        "norms",
                        "encoder_self_attention_postprocess_%d" % (layer),
                        tf.norm(x, axis=-1))
                if query is not None:
                    with tf.variable_scope("encdec_attention"):
                        y = common_attention.multihead_attention(
                            common_layers.layer_preprocess(x, hparams),
                            query,
                            None,
                            hparams.attention_key_channels
                            or hparams.hidden_size,
                            hparams.attention_value_channels
                            or hparams.hidden_size,
                            hparams.hidden_size,
                            hparams.num_heads,
                            hparams.attention_dropout,
                            attention_type=hparams.self_attention_type,
                            block_length=hparams.block_length,
                            save_weights_to=save_weights_to,
                            make_image_summary=make_image_summary,
                            scale_dotproduct=hparams.scale_dotproduct,
                        )
                        utils.collect_named_outputs(
                            "norms", "encoder_decoder_attention_%d" % (layer),
                            tf.norm(y, axis=-1))
                        x = common_layers.layer_postprocess(x, y, hparams)
                        utils.collect_named_outputs(
                            "norms",
                            "encoder_decoder_attention_post_%d" % (layer),
                            tf.norm(x, axis=-1))
                with tf.variable_scope("ffn"):
                    y = common_layers.dense_relu_dense(
                        common_layers.layer_preprocess(x, hparams),
                        hparams.filter_size,
                        hparams.hidden_size,
                        dropout=hparams.relu_dropout,
                    )
                    utils.collect_named_outputs("norms",
                                                "encoder_ffn_%d" % (layer),
                                                tf.norm(y, axis=-1))
                    x = common_layers.layer_postprocess(x, y, hparams)
                    utils.collect_named_outputs(
                        "norms", "encoder_ffn_postprocess_%d" % (layer),
                        tf.norm(x, axis=-1))
        # if normalization is done in layer_preprocess, then it should also be done
        # on the output, since the output can grow very large, being the sum of
        # a whole stack of unnormalized layer outputs.
        return common_layers.layer_preprocess(x, hparams)
Example #16
def transformer_decoder_layers(inputs,
                               encoder_output,
                               num_layers,
                               hparams,
                               self_attention_bias=None,
                               encoder_decoder_attention_bias=None,
                               attention_type=AttentionType.LOCAL_2D,
                               name="transformer"):
    """Multi layer transformer."""
    x = inputs
    x = tf.nn.dropout(x, 1.0 - hparams.layer_prepostprocess_dropout)
    if attention_type == AttentionType.DILATED:
        assert len(hparams.gap_sizes) == num_layers
    for layer in xrange(num_layers):
        with tf.variable_scope("%s_layer_%d" % (name, layer)):
            # self-attention + skip connections
            if attention_type == AttentionType.LOCAL_2D:
                y = local_attention_2d(
                    common_layers.layer_preprocess(x, hparams),
                    hparams,
                    attention_type="masked_local_attention_2d")
            elif attention_type == AttentionType.LOCAL_1D:
                y = local_attention_1d(common_layers.layer_preprocess(
                    x, hparams),
                                       hparams,
                                       attention_type="local_mask_right",
                                       q_padding="LEFT",
                                       kv_padding="LEFT")
            elif attention_type == AttentionType.LOCAL_BLOCK:
                y = local_within_block_attention(
                    common_layers.layer_preprocess(x, hparams),
                    self_attention_bias,
                    hparams,
                    attention_type="local_within_block_mask_right",
                    q_padding="LEFT",
                    kv_padding="LEFT")
            elif attention_type == AttentionType.GLOCAL:
                y = local_global_attention(common_layers.layer_preprocess(
                    x, hparams),
                                           self_attention_bias,
                                           hparams,
                                           q_padding="LEFT",
                                           kv_padding="LEFT")
            elif attention_type == AttentionType.DILATED:
                y = dilated_attention_1d(common_layers.layer_preprocess(
                    x, hparams),
                                         hparams,
                                         q_padding="LEFT",
                                         kv_padding="LEFT",
                                         gap_size=hparams.gap_sizes[layer])
            elif attention_type == AttentionType.GLOBAL:
                y = full_self_attention(common_layers.layer_preprocess(
                    x, hparams),
                                        self_attention_bias,
                                        hparams,
                                        q_padding="LEFT",
                                        kv_padding="LEFT")
            x = common_layers.layer_postprocess(x, y, hparams)
            # enc-dec attention + skip connections
            if encoder_output is not None:
                y = encdec_attention_1d(
                    common_layers.layer_preprocess(x, hparams), encoder_output,
                    encoder_decoder_attention_bias, hparams)
                x = common_layers.layer_postprocess(x, y, hparams)
            # feed-fwd layers + skip connections
            y = ffn_layer(common_layers.layer_preprocess(x, hparams), hparams)
            x = common_layers.layer_postprocess(x, y, hparams)
    return common_layers.layer_preprocess(x, hparams)
def evolved_transformer_decoder(decoder_input,
                                encoder_output,
                                decoder_self_attention_bias,
                                encoder_decoder_attention_bias,
                                hparams,
                                cache=None,
                                decode_loop_step=None,
                                name="decoder",
                                nonpadding=None,
                                save_weights_to=None,
                                make_image_summary=True,
                                losses=None):
    """Evolved Transformer decoder. See arxiv.org/abs/1901.11117 for more details.

  Args:
    decoder_input: a Tensor.
    encoder_output: a Tensor.
    decoder_self_attention_bias: bias Tensor for self-attention (see
      common_attention.attention_bias()).
    encoder_decoder_attention_bias: bias Tensor for encoder-decoder attention
      (see common_attention.attention_bias()).
    hparams: hyperparameters for model.
    cache: Not supported.
    decode_loop_step: An integer, step number of the decoding loop. Only used
      for inference on TPU.
    name: a string.
    nonpadding: optional Tensor with shape [batch_size, encoder_length]
      indicating what positions are not padding.  This is used to mask out
      padding in convolutional layers.  We generally only need this mask for
      "packed" datasets, because for ordinary datasets, no padding is ever
      followed by nonpadding.
    save_weights_to: an optional dictionary to capture attention weights for
      visualization; the weights tensor will be appended there under a string
      key created from the variable scope (including name).
    make_image_summary: Whether to make an attention image summary.
    losses: Not supported.

  Returns:
    Decoder output tensor.
  """
    del cache, losses

    attention_dropout_broadcast_dims = (
        common_layers.comma_separated_string_to_integer_list(
            getattr(hparams, "attention_dropout_broadcast_dims", "")))

    with tf.variable_scope(name):
        hidden_state = decoder_input
        layer_cache = None

        for layer in range(hparams.num_decoder_layers
                           or hparams.num_hidden_layers):
            with tf.variable_scope("layer_%d" % layer):

                with tf.variable_scope("16_head_self_attention"):
                    residual_state = hidden_state
                    hidden_state = common_layers.layer_preprocess(
                        hidden_state, hparams)

                    # 16 head attention. Hard coding number of heads.
                    left_state = common_attention.multihead_attention(
                        hidden_state,
                        None,
                        decoder_self_attention_bias,
                        hparams.attention_key_channels or hparams.hidden_size,
                        hparams.attention_value_channels
                        or hparams.hidden_size,
                        hparams.hidden_size,
                        16,  # Heads are hard coded to replicate paper.
                        hparams.attention_dropout,
                        attention_type=hparams.self_attention_type,
                        max_relative_position=hparams.max_relative_position,
                        heads_share_relative_embedding=(
                            hparams.heads_share_relative_embedding),
                        add_relative_to_values=hparams.add_relative_to_values,
                        save_weights_to=save_weights_to,
                        cache=layer_cache,
                        make_image_summary=make_image_summary,
                        dropout_broadcast_dims=attention_dropout_broadcast_dims,
                        max_length=hparams.get("max_length"),
                        decode_loop_step=decode_loop_step,
                        vars_3d=hparams.get("attention_variables_3d"),
                        activation_dtype=hparams.get("activation_dtype",
                                                     "float32"),
                        weight_dtype=hparams.get("weight_dtype", "float32"))

                if encoder_output is not None:
                    with tf.variable_scope("first_attend_to_encoder"):
                        right_state = common_attention.multihead_attention(
                            hidden_state,
                            encoder_output,
                            encoder_decoder_attention_bias,
                            hparams.attention_key_channels
                            or hparams.hidden_size,
                            hparams.attention_value_channels
                            or hparams.hidden_size,
                            hparams.hidden_size,
                            hparams.num_heads,
                            hparams.attention_dropout,
                            max_relative_position=hparams.
                            max_relative_position,
                            heads_share_relative_embedding=(
                                hparams.heads_share_relative_embedding),
                            add_relative_to_values=hparams.
                            add_relative_to_values,
                            save_weights_to=save_weights_to,
                            cache=layer_cache,
                            make_image_summary=make_image_summary,
                            dropout_broadcast_dims=
                            attention_dropout_broadcast_dims,
                            max_length=hparams.get("max_length"),
                            vars_3d=hparams.get("attention_variables_3d"),
                            activation_dtype=hparams.get(
                                "activation_dtype", "float32"),
                            weight_dtype=hparams.get("weight_dtype",
                                                     "float32"))

                        left_state = tf.nn.dropout(
                            left_state,
                            1 - hparams.layer_prepostprocess_dropout)
                        right_state = tf.nn.dropout(
                            right_state,
                            1 - hparams.layer_prepostprocess_dropout)

                        hidden_state = residual_state + left_state + right_state

                else:
                    hidden_state = common_layers.layer_postprocess(
                        residual_state, left_state, hparams)

                with tf.variable_scope("conv_branches"):
                    residual_state = hidden_state
                    hidden_state = common_layers.layer_preprocess(
                        hidden_state, hparams)

                    if nonpadding is not None:
                        # Mask padding from conv layers.
                        mask = tf.tile(tf.expand_dims(nonpadding, 2),
                                       [1, 1, hparams.hidden_size])
                        hidden_state *= mask

                    # Shift inputs so that future tokens cannot be seen.
                    left_state = tf.pad(hidden_state,
                                        paddings=[[0, 0], [10, 0], [0, 0]])
                    left_output_dim = int(hparams.hidden_size * 2)
                    separable_conv_11x1 = tf.layers.SeparableConv1D(
                        left_output_dim,
                        11,
                        padding="VALID",
                        name="separable_conv11x1",
                        activation=tf.nn.relu)
                    left_state = separable_conv_11x1.apply(left_state)
                    left_state = tf.nn.dropout(
                        left_state, 1 - hparams.layer_prepostprocess_dropout)

                    right_state = tf.pad(hidden_state,
                                         paddings=[[0, 0], [6, 0], [0, 0]])
                    right_output_dim = int(hparams.hidden_size / 2)
                    separable_conv_7x1_1 = tf.layers.SeparableConv1D(
                        right_output_dim,
                        7,
                        padding="VALID",
                        name="separable_conv_7x1_1")
                    right_state = separable_conv_7x1_1.apply(right_state)
                    right_state = tf.nn.dropout(
                        right_state, 1 - hparams.layer_prepostprocess_dropout)
                    right_state = tf.pad(
                        right_state, [[0, 0], [0, 0],
                                      [0, left_output_dim - right_output_dim]],
                        constant_values=0)

                    hidden_state = left_state + right_state

                    hidden_state = common_layers.layer_preprocess(
                        hidden_state, hparams)
                    if nonpadding is not None:
                        # Mask padding from conv layers.
                        mask = tf.tile(tf.expand_dims(nonpadding, 2),
                                       [1, 1, hparams.hidden_size * 2])
                        hidden_state *= mask

                    hidden_state = tf.pad(hidden_state,
                                          paddings=[[0, 0], [6, 0], [0, 0]])
                    separable_conv_7x1_2 = tf.layers.SeparableConv1D(
                        hparams.hidden_size,
                        7,
                        padding="VALID",
                        name="separable_conv_7x1_2")
                    hidden_state = separable_conv_7x1_2.apply(hidden_state)

                    hidden_state = common_layers.layer_postprocess(
                        residual_state, hidden_state, hparams)

                with tf.variable_scope("self_attention"):
                    residual_state = hidden_state
                    hidden_state = common_layers.layer_preprocess(
                        hidden_state, hparams)

                    hidden_state = common_attention.multihead_attention(
                        hidden_state,
                        None,
                        decoder_self_attention_bias,
                        hparams.attention_key_channels or hparams.hidden_size,
                        hparams.attention_value_channels
                        or hparams.hidden_size,
                        hparams.hidden_size,
                        hparams.num_heads,
                        hparams.attention_dropout,
                        attention_type=hparams.self_attention_type,
                        max_relative_position=hparams.max_relative_position,
                        heads_share_relative_embedding=(
                            hparams.heads_share_relative_embedding),
                        add_relative_to_values=hparams.add_relative_to_values,
                        save_weights_to=save_weights_to,
                        cache=layer_cache,
                        make_image_summary=make_image_summary,
                        dropout_broadcast_dims=attention_dropout_broadcast_dims,
                        max_length=hparams.get("max_length"),
                        decode_loop_step=decode_loop_step,
                        vars_3d=hparams.get("attention_variables_3d"),
                        activation_dtype=hparams.get("activation_dtype",
                                                     "float32"),
                        weight_dtype=hparams.get("weight_dtype", "float32"))
                    hidden_state = common_layers.layer_postprocess(
                        residual_state, hidden_state, hparams)

                if encoder_output is not None:
                    with tf.variable_scope("second_attend_to_encoder"):
                        residual_state = hidden_state
                        hidden_state = common_layers.layer_preprocess(
                            hidden_state, hparams)

                        hidden_state = common_attention.multihead_attention(
                            hidden_state,
                            encoder_output,
                            encoder_decoder_attention_bias,
                            hparams.attention_key_channels
                            or hparams.hidden_size,
                            hparams.attention_value_channels
                            or hparams.hidden_size,
                            hparams.hidden_size,
                            hparams.num_heads,
                            hparams.attention_dropout,
                            max_relative_position=hparams.
                            max_relative_position,
                            heads_share_relative_embedding=(
                                hparams.heads_share_relative_embedding),
                            add_relative_to_values=hparams.
                            add_relative_to_values,
                            save_weights_to=save_weights_to,
                            cache=layer_cache,
                            make_image_summary=make_image_summary,
                            dropout_broadcast_dims=
                            attention_dropout_broadcast_dims,
                            max_length=hparams.get("max_length"),
                            vars_3d=hparams.get("attention_variables_3d"),
                            activation_dtype=hparams.get(
                                "activation_dtype", "float32"),
                            weight_dtype=hparams.get("weight_dtype",
                                                     "float32"))
                        hidden_state = common_layers.layer_postprocess(
                            residual_state, hidden_state, hparams)

                with tf.variable_scope("dense_layers"):
                    residual_state = hidden_state
                    hidden_state = common_layers.layer_preprocess(
                        hidden_state, hparams)

                    hidden_state = tf.layers.dense(hidden_state,
                                                   int(hparams.hidden_size *
                                                       4),
                                                   activation=tf.nn.swish)
                    hidden_state = tf.nn.dropout(
                        hidden_state, 1 - hparams.layer_prepostprocess_dropout)

                    hidden_state = common_layers.layer_preprocess(
                        hidden_state, hparams)

                    hidden_state = tf.layers.dense(hidden_state,
                                                   hparams.hidden_size)
                    hidden_state = common_layers.layer_postprocess(
                        residual_state, hidden_state, hparams)

        return common_layers.layer_preprocess(hidden_state, hparams)
Example #18
def ffn(x, hparams, name):
    with tf.variable_scope(name):
        y = transformer.transformer_ffn_layer(
            common_layers.layer_preprocess(x, hparams), hparams)
        return common_layers.layer_postprocess(x, y, hparams)
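Example #18 shows the pattern shared by every example on this page in its smallest form: each sublayer is wrapped as y = sublayer(layer_preprocess(x)) followed by layer_postprocess(x, y). A short sketch of that idiom as a generic helper (the helper name is invented for illustration):

def residual_sublayer(x, sublayer_fn, hparams, name):
  """Apply sublayer_fn inside the standard pre/post-processing residual wrapper."""
  with tf.variable_scope(name):
    # layer_preprocess typically normalizes the input; layer_postprocess applies
    # dropout and the residual add, as configured by
    # hparams.layer_preprocess_sequence and hparams.layer_postprocess_sequence.
    y = sublayer_fn(common_layers.layer_preprocess(x, hparams))
    return common_layers.layer_postprocess(x, y, hparams)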
Example #19
def transformer_encoder(encoder_input,
                        encoder_self_attention_bias,
                        hparams,
                        name="encoder",
                        nonpadding=None,
                        save_weights_to=None):
    """A stack of transformer layers.

  Args:
    encoder_input: a Tensor
    encoder_self_attention_bias: bias Tensor for self-attention
       (see common_attention.attention_bias())
    hparams: hyperparameters for model
    name: a string
    nonpadding: optional Tensor with shape [batch_size, encoder_length]
      indicating what positions are not padding.  This must either be
      passed in, which we do for "packed" datasets, or inferred from
      encoder_self_attention_bias.  The knowledge about padding is used
      for pad_remover (efficiency) and to mask out padding in convolutional
      layers.
    save_weights_to: an optional dictionary to capture attention weights
      for visualization; the weights tensor will be appended there under
      a string key created from the variable scope (including name).

  Returns:
    y: a Tensor
  """
    x = encoder_input
    with tf.variable_scope(name):
        if nonpadding is not None:
            padding = 1.0 - nonpadding
        else:
            padding = common_attention.attention_bias_to_padding(
                encoder_self_attention_bias)
            nonpadding = 1.0 - padding
        pad_remover = None
        if hparams.use_pad_remover:
            pad_remover = expert_utils.PadRemover(padding)
        for layer in xrange(hparams.num_encoder_layers
                            or hparams.num_hidden_layers):
            with tf.variable_scope("layer_%d" % layer):
                with tf.variable_scope("self_attention"):
                    y = common_attention.multihead_attention(
                        common_layers.layer_preprocess(x, hparams),
                        None,
                        encoder_self_attention_bias,
                        hparams.attention_key_channels or hparams.hidden_size,
                        hparams.attention_value_channels
                        or hparams.hidden_size,
                        hparams.hidden_size,
                        hparams.num_heads,
                        hparams.attention_dropout,
                        attention_type=hparams.self_attention_type,
                        save_weights_to=save_weights_to,
                        max_relative_position=hparams.max_relative_position)
                    x = common_layers.layer_postprocess(x, y, hparams)
                with tf.variable_scope("ffn"):
                    y = transformer_ffn_layer(common_layers.layer_preprocess(
                        x, hparams),
                                              hparams,
                                              pad_remover,
                                              conv_padding="SAME",
                                              nonpadding_mask=nonpadding)
                    x = common_layers.layer_postprocess(x, y, hparams)
        # if normalization is done in layer_preprocess, then it should also be done
        # on the output, since the output can grow very large, being the sum of
        # a whole stack of unnormalized layer outputs.
        return common_layers.layer_preprocess(x, hparams)
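The encoder_self_attention_bias passed to transformer_encoder is normally derived from the padding mask, which is also what the function falls back to when nonpadding is not given. A hedged sketch, assuming common_attention.embedding_to_padding and attention_bias_ignore_padding:

# Padding mask -> additive attention bias (large negative values at padded positions).
padding = common_attention.embedding_to_padding(encoder_input)
encoder_self_attention_bias = common_attention.attention_bias_ignore_padding(padding)
encoder_output = transformer_encoder(
    encoder_input, encoder_self_attention_bias, hparams,
    nonpadding=1.0 - padding)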
Example #20
def transformer_decoder_gate(decoder_input,
                             encoder_output,
                             decoder_self_attention_bias,
                             encoder_decoder_attention_bias,
                             hparams,
                             cache=None,
                             name="decoder"):
  """A stack of transformer layers.

  Args:
    decoder_input: a Tensor
    encoder_output: a Tensor
    decoder_self_attention_bias: bias Tensor for self-attention
      (see common_attention.attention_bias())
    encoder_decoder_attention_bias: bias Tensor for encoder-decoder attention
      (see common_attention.attention_bias())
    hparams: hyperparameters for model
    cache: dict, containing tensors which are the results of previous
        attentions, used for fast decoding.
    name: a string

  Returns:
    y: a Tensor
  """
  x = decoder_input
  with tf.variable_scope(name):
    for layer in xrange(hparams.num_decoder_layers or
                        hparams.num_hidden_layers):
      layer_name = "layer_%d" % layer
      layer_cache = cache[layer_name] if cache is not None else None
      with tf.variable_scope(layer_name):
        with tf.variable_scope("self_attention"):
          y = common_attention.multihead_attention(
              common_layers.layer_preprocess(x, hparams),
              None,
              decoder_self_attention_bias,
              hparams.attention_key_channels or hparams.hidden_size,
              hparams.attention_value_channels or hparams.hidden_size,
              hparams.hidden_size,
              hparams.num_heads,
              hparams.attention_dropout,
              attention_type=hparams.self_attention_type,
              max_relative_position=hparams.max_relative_position,
              cache=layer_cache)
          x = common_layers.layer_postprocess(x, y, hparams)

          gate_fiter = tf.get_variable(
              'gate_layer_%d' % layer,
              [1, hparams.hidden_size, hparams.hidden_size],
              tf.float32, initializer=tf.contrib.layers.xavier_initializer())
          gate_x = tf.tanh(
              tf.nn.conv1d(x, gate_fiter, 1, 'SAME'))
          x *= gate_x

        if encoder_output is not None:
          with tf.variable_scope("encdec_attention"):
            # TODO(llion): Add caching.
            y = common_attention.multihead_attention(
                common_layers.layer_preprocess(x, hparams),
                encoder_output,
                encoder_decoder_attention_bias,
                hparams.attention_key_channels or hparams.hidden_size,
                hparams.attention_value_channels or hparams.hidden_size,
                hparams.hidden_size, hparams.num_heads,
                hparams.attention_dropout)
            x = common_layers.layer_postprocess(x, y, hparams)

            gate_fiter = tf.get_variable(
                'gate_layer_%d' % layer,
                [1, hparams.hidden_size, hparams.hidden_size],
                tf.float32, initializer=tf.contrib.layers.xavier_initializer())
            gate_x = tf.tanh(
                tf.nn.conv1d(x, gate_fiter, 1, 'SAME'))
            x *= gate_x
        with tf.variable_scope("ffn"):
          y = transformer_ffn_layer(
              common_layers.layer_preprocess(x, hparams), hparams)
          x = common_layers.layer_postprocess(x, y, hparams)

          gate_filter = tf.get_variable(
              'gate_layer_%d' % layer,
              [1, hparams.hidden_size, hparams.hidden_size],
              tf.float32, initializer=tf.contrib.layers.xavier_initializer())
          gate_x = tf.tanh(
              tf.nn.conv1d(x, gate_filter, 1, 'SAME'))
          x *= gate_x
    # if normalization is done in layer_preprocess, then it should also be done
    # on the output, since the output can grow very large, being the sum of
    # a whole stack of unnormalized layer outputs.
    return common_layers.layer_preprocess(x, hparams)
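# A minimal standalone sketch of the tanh gate that transformer_decoder_gate
# above applies after each sublayer. The helper name apply_conv_gate is
# illustrative and not part of the original code; it assumes only the TF1 APIs
# already used above (tf.get_variable, tf.nn.conv1d, tf.tanh) and the usual
# `import tensorflow as tf`.
def apply_conv_gate(x, layer, hparams):
  """Scales x elementwise by tanh(conv1d(x, W)), with W of shape [1, h, h]."""
  gate_filter = tf.get_variable(
      'gate_layer_%d' % layer,
      [1, hparams.hidden_size, hparams.hidden_size],
      tf.float32, initializer=tf.contrib.layers.xavier_initializer())
  return x * tf.tanh(tf.nn.conv1d(x, gate_filter, 1, 'SAME'))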
Example #21
def transformer_decoder(decoder_input,
                        encoder_output,
                        decoder_self_attention_bias,
                        encoder_decoder_attention_bias,
                        hparams,
                        cache=None,
                        name="decoder",
                        nonpadding=None,
                        save_weights_to=None,
                        make_image_summary=True):
  """A stack of transformer layers.

  Args:
    decoder_input: a Tensor
    encoder_output: a Tensor
    decoder_self_attention_bias: bias Tensor for self-attention
      (see common_attention.attention_bias())
    encoder_decoder_attention_bias: bias Tensor for encoder-decoder attention
      (see common_attention.attention_bias())
    hparams: hyperparameters for model
    cache: dict, containing tensors which are the results of previous
        attentions, used for fast decoding.
    name: a string
    nonpadding: optional Tensor with shape [batch_size, encoder_length]
      indicating what positions are not padding.  This is used
      to mask out padding in convolutional layers.  We generally only
      need this mask for "packed" datasets, because for ordinary datasets,
      no padding is ever followed by nonpadding.
    save_weights_to: an optional dictionary to capture attention weights
      for visualization; the weights tensor will be appended there under
      a string key created from the variable scope (including name).
    make_image_summary: Whether to make an attention image summary.

  Returns:
    y: a Tensor
  """
  x = decoder_input
  attention_dropout_broadcast_dims = (
      common_layers.comma_separated_string_to_integer_list(
          getattr(hparams, "attention_dropout_broadcast_dims", "")))
  with tf.variable_scope(name):
    for layer in xrange(hparams.num_decoder_layers or
                        hparams.num_hidden_layers):
      layer_name = "layer_%d" % layer
      layer_cache = cache[layer_name] if cache is not None else None
      with tf.variable_scope(layer_name):
        with tf.variable_scope("self_attention"):
          y = common_attention.multihead_attention(
              common_layers.layer_preprocess(x, hparams),
              None,
              decoder_self_attention_bias,
              hparams.attention_key_channels or hparams.hidden_size,
              hparams.attention_value_channels or hparams.hidden_size,
              hparams.hidden_size,
              hparams.num_heads,
              hparams.attention_dropout,
              attention_type=hparams.self_attention_type,
              save_weights_to=save_weights_to,
              max_relative_position=hparams.max_relative_position,
              cache=layer_cache,
              make_image_summary=make_image_summary,
              dropout_broadcast_dims=attention_dropout_broadcast_dims)
          x = common_layers.layer_postprocess(x, y, hparams)
        if encoder_output is not None:
          with tf.variable_scope("encdec_attention"):
            # TODO(llion): Add caching.
            y = common_attention.multihead_attention(
                common_layers.layer_preprocess(x, hparams),
                encoder_output,
                encoder_decoder_attention_bias,
                hparams.attention_key_channels or hparams.hidden_size,
                hparams.attention_value_channels or hparams.hidden_size,
                hparams.hidden_size,
                hparams.num_heads,
                hparams.attention_dropout,
                save_weights_to=save_weights_to,
                make_image_summary=make_image_summary,
                dropout_broadcast_dims=attention_dropout_broadcast_dims)
            x = common_layers.layer_postprocess(x, y, hparams)
        with tf.variable_scope("ffn"):
          y = transformer_ffn_layer(
              common_layers.layer_preprocess(x, hparams), hparams,
              conv_padding="LEFT", nonpadding_mask=nonpadding)
          x = common_layers.layer_postprocess(x, y, hparams)
    # if normalization is done in layer_preprocess, then it should also be done
    # on the output, since the output can grow very large, being the sum of
    # a whole stack of unnormalized layer outputs.
    return common_layers.layer_preprocess(x, hparams)
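# Hedged usage sketch for transformer_decoder above; variable names are
# illustrative. The causal self-attention bias is typically built with
# common_attention.attention_bias_lower_triangle, and decoder_input is the
# shifted-right, embedded-and-position-encoded target sequence.
decoder_self_attention_bias = common_attention.attention_bias_lower_triangle(
    tf.shape(decoder_input)[1])
decoder_output = transformer_decoder(
    decoder_input, encoder_output, decoder_self_attention_bias,
    encoder_decoder_attention_bias, hparams)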
Example #22
def transformer_decoder(decoder_input,
                        encoder_output,
                        decoder_self_attention_bias,
                        encoder_decoder_attention_bias,
                        hparams,
                        cache=None,
                        name="decoder",
                        terminal_decoder_bias=None,
                        nonterminal_decoder_bias=None,
                        nonpadding=None,
                        pos_signals=None):
    """A stack of transformer layers.

  Args:
    decoder_input: a Tensor
    encoder_output: a Tensor
    decoder_self_attention_bias: bias Tensor for self-attention
      (see common_attention.attention_bias())
    encoder_decoder_attention_bias: bias Tensor for encoder-decoder attention
      (see common_attention.attention_bias())
    hparams: hyperparameters for model
    cache: dict, containing tensors which are the results of previous
        attentions, used for fast decoding.
    name: a string
    nonpadding: optional Tensor with shape [batch_size, encoder_length]
      indicating what positions are not padding.  This is used
      to mask out padding in convolutional layers.  We generally only
      need this mask for "packed" datasets, because for ordinary datasets,
      no padding is ever followed by nonpadding.

  Returns:
    y: a Tensor
  """
    x = decoder_input
    sequence_length = usr_utils.get_length_from_nonpadding(nonpadding)
    with tf.variable_scope(name):
        for layer in xrange(hparams.num_decoder_layers
                            or hparams.num_hidden_layers):
            layer_name = "layer_%d" % layer
            layer_cache = cache[layer_name] if cache is not None else None
            with tf.variable_scope(layer_name):
                for layer_type in _iter_layer_types(
                        hparams.decoder_layer_types, layer):
                    if layer_type == "self_att":
                        with tf.variable_scope("self_attention"):
                            y = model_helper.multihead_attention_qkv(
                                common_layers.layer_preprocess(x, hparams),
                                None,
                                None,
                                decoder_self_attention_bias,
                                hparams.attention_key_channels
                                or hparams.hidden_size,
                                hparams.attention_value_channels
                                or hparams.hidden_size,
                                hparams.hidden_size,
                                hparams.num_heads,
                                hparams.attention_dropout,
                                attention_type=hparams.
                                decoder_self_attention_type,
                                attention_order=hparams.attention_order,
                                max_relative_position=hparams.
                                max_relative_position,
                                cache=layer_cache)
                            x = common_layers.layer_postprocess(x, y, hparams)
                    elif layer_type == "nt_self_att":
                        with tf.variable_scope("nonterminal_self_attention"):
                            y = model_helper.multihead_attention_qkv(
                                common_layers.layer_preprocess(x, hparams),
                                None,
                                None,
                                nonterminal_decoder_bias,
                                hparams.attention_key_channels
                                or hparams.hidden_size,
                                hparams.attention_value_channels
                                or hparams.hidden_size,
                                hparams.hidden_size,
                                hparams.num_heads,
                                hparams.attention_dropout,
                                attention_type=hparams.
                                decoder_self_attention_type,
                                attention_order=hparams.attention_order,
                                max_relative_position=hparams.
                                max_relative_position,
                                cache=layer_cache)
                            x = common_layers.layer_postprocess(x, y, hparams)
                    elif layer_type == "t_self_att":
                        with tf.variable_scope("terminal_self_attention"):
                            y = model_helper.multihead_attention_qkv(
                                common_layers.layer_preprocess(x, hparams),
                                None,
                                None,
                                terminal_decoder_bias,
                                hparams.attention_key_channels
                                or hparams.hidden_size,
                                hparams.attention_value_channels
                                or hparams.hidden_size,
                                hparams.hidden_size,
                                hparams.num_heads,
                                hparams.attention_dropout,
                                attention_type=hparams.
                                decoder_self_attention_type,
                                attention_order=hparams.attention_order,
                                max_relative_position=hparams.
                                max_relative_position,
                                cache=layer_cache)
                            x = common_layers.layer_postprocess(x, y, hparams)
                    elif layer_type == "parent_ffn":
                        with tf.variable_scope("parent_ffn"):
                            parent_pointers = tf.cast(
                                pos_signals["parent_timing"], tf.int32)
                            parent_x = usr_utils.gather_2d(x, parent_pointers)
                            x = tf.concat([x, parent_x], axis=2)
                            x = transformer_ffn_layer(x,
                                                      hparams,
                                                      conv_padding="LEFT")
                    elif layer_type == "rnn":
                        with tf.variable_scope("recurrent"):
                            y = transformer_rnn_layer(
                                common_layers.layer_preprocess(x, hparams),
                                sequence_length, hparams)
                            x = common_layers.layer_postprocess(x, y, hparams)
                    elif layer_type == "enc_att" and encoder_output is not None:
                        with tf.variable_scope("encdec_attention"):
                            # TODO(llion): Add caching.
                            y = model_helper.multihead_attention_qkv(
                                common_layers.layer_preprocess(x, hparams),
                                encoder_output, None,
                                encoder_decoder_attention_bias,
                                hparams.attention_key_channels
                                or hparams.hidden_size,
                                hparams.attention_value_channels
                                or hparams.hidden_size, hparams.hidden_size,
                                hparams.num_heads, hparams.attention_dropout)
                            x = common_layers.layer_postprocess(x, y, hparams)
                    else:
                        tf.logging.warn(
                            "Ignoring '%s' in decoder_layer_types" %
                            layer_type)
                with tf.variable_scope("ffn"):
                    y = transformer_ffn_layer(common_layers.layer_preprocess(
                        x, hparams),
                                              hparams,
                                              conv_padding="LEFT",
                                              nonpadding_mask=nonpadding)
                    x = common_layers.layer_postprocess(x, y, hparams)
        # if normalization is done in layer_preprocess, then it should also be done
        # on the output, since the output can grow very large, being the sum of
        # a whole stack of unnormalized layer outputs.
        return common_layers.layer_preprocess(x, hparams)
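# _iter_layer_types is not shown in this snippet. A plausible sketch of its
# contract (an assumption, not the original implementation): it yields the
# sublayer-type strings configured for layer `layer`, e.g. from a spec string
# such as "self_att,enc_att;rnn,enc_att" where ";" separates layers and ","
# separates the sublayer types within one layer.
def _iter_layer_types_sketch(layer_types_spec, layer):
  per_layer_specs = layer_types_spec.split(";")
  # Reuse the last per-layer spec if there are more layers than specs.
  spec = per_layer_specs[min(layer, len(per_layer_specs) - 1)]
  for layer_type in spec.split(","):
    if layer_type:
      yield layer_type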
Example #23
def transformer_encoder(encoder_input,
                        encoder_self_attention_bias,
                        hparams,
                        name="encoder",
                        nonpadding=None):
    """A stack of transformer layers.

  Args:
    encoder_input: a Tensor
    encoder_self_attention_bias: bias Tensor for self-attention
       (see common_attention.attention_bias())
    hparams: hyperparameters for model
    name: a string
    nonpadding: optional Tensor with shape [batch_size, encoder_length]
      indicating what positions are not padding.  This must either be
      passed in, which we do for "packed" datasets, or inferred from
      encoder_self_attention_bias.  The knowledge about padding is used
      for pad_remover(efficiency) and to mask out padding in convolutional
      layers.

  Returns:
    y: a Tensor
  """
    x = encoder_input
    with tf.variable_scope(name):
        if nonpadding is not None:
            padding = 1.0 - nonpadding
        else:
            padding = common_attention.attention_bias_to_padding(
                encoder_self_attention_bias)
            nonpadding = 1.0 - padding
        pad_remover = None
        if hparams.use_pad_remover:
            pad_remover = expert_utils.PadRemover(padding)
        sequence_length = usr_utils.get_length_from_nonpadding(nonpadding)
        for layer in xrange(hparams.num_encoder_layers
                            or hparams.num_hidden_layers):
            with tf.variable_scope("layer_%d" % layer):
                for layer_type in _iter_layer_types(
                        hparams.encoder_layer_types, layer):
                    if layer_type == "self_att":
                        with tf.variable_scope("self_attention"):
                            y = model_helper.multihead_attention_qkv(
                                common_layers.layer_preprocess(x, hparams),
                                None,
                                None,
                                encoder_self_attention_bias,
                                hparams.attention_key_channels
                                or hparams.hidden_size,
                                hparams.attention_value_channels
                                or hparams.hidden_size,
                                hparams.hidden_size,
                                hparams.num_heads,
                                hparams.attention_dropout,
                                attention_type=hparams.
                                encoder_self_attention_type,
                                attention_order=hparams.attention_order,
                                max_relative_position=hparams.
                                max_relative_position)
                            x = common_layers.layer_postprocess(x, y, hparams)
                    elif layer_type == "rnn":
                        with tf.variable_scope("recurrent"):
                            y = transformer_rnn_layer(
                                common_layers.layer_preprocess(x, hparams),
                                sequence_length, hparams)
                            x = common_layers.layer_postprocess(x, y, hparams)
                    elif layer_type == "birnn":
                        with tf.variable_scope("recurrent"):
                            y = transformer_rnn_layer(
                                common_layers.layer_preprocess(x, hparams),
                                sequence_length,
                                hparams,
                                bidirectional=True)
                            x = common_layers.layer_postprocess(x, y, hparams)
                    else:
                        tf.logging.warn(
                            "Ignoring '%s' in encoder_layer_types" %
                            layer_type)
                with tf.variable_scope("ffn"):
                    y = transformer_ffn_layer(common_layers.layer_preprocess(
                        x, hparams),
                                              hparams,
                                              pad_remover,
                                              conv_padding="SAME",
                                              nonpadding_mask=nonpadding)
                    x = common_layers.layer_postprocess(x, y, hparams)
        # if normalization is done in layer_preprocess, then it should also be done
        # on the output, since the output can grow very large, being the sum of
        # a whole stack of unnormalized layer outputs.
        return common_layers.layer_preprocess(x, hparams)
def transformer_decoder_fast_aan(decoder_input,
                                 encoder_output,
                                 decoder_position_forward_mask,
                                 encoder_decoder_attention_bias,
                                 hparams,
                                 cache=None,
                                 name="decoder"):
    """A stack of transformer layers.
  Args:
    decoder_input: a Tensor
    encoder_output: a Tensor
    decoder_position_forward_mask: mask Tensor for position-forward / shape: [1, t, 1]
    encoder_decoder_attention_bias: bias Tensor for encoder-decoder attention
      (see common_attention.attention_bias())
    hparams: hyperparameters for model
    cache: dict, containing tensors which are the results of previous
        attentions, used for fast decoding.
    name: a string
  Returns:
    y: a Tensor
  """
    x = decoder_input
    with tf.variable_scope(name):
        for layer in range(hparams.num_decoder_layers
                           or hparams.num_hidden_layers):
            layer_name = "layer_%d" % layer
            layer_cache = cache[layer_name] if cache is not None else None
            with tf.variable_scope(layer_name):
                with tf.variable_scope("position_forward"):
                    if layer_cache:
                        given_inputs_new = layer_cache['given_inputs'] + x
                        x_fwd = given_inputs_new * decoder_position_forward_mask
                        layer_cache['given_inputs'] = given_inputs_new
                    else:
                        x_fwd = tf.cumsum(
                            x, axis=1) * decoder_position_forward_mask
                    # FFN activation
                    y = transformer.transformer_ffn_layer(
                        common_layers.layer_preprocess(x_fwd, hparams),
                        hparams)

                    # Gating layer
                    z = tf.layers.dense(tf.concat([x, y], axis=-1),
                                        hparams.hidden_size * 2,
                                        name="z_project")
                    i, f = tf.split(z, 2, axis=-1)
                    y = tf.sigmoid(i) * x + tf.sigmoid(f) * y
                    x = common_layers.layer_postprocess(x, y, hparams)

                if encoder_output is not None:
                    with tf.variable_scope("encdec_attention"):
                        y = multihead_attention(
                            common_layers.layer_preprocess(x, hparams),
                            encoder_output,
                            encoder_decoder_attention_bias,
                            hparams.attention_key_channels
                            or hparams.hidden_size,
                            hparams.attention_value_channels
                            or hparams.hidden_size,
                            hparams.hidden_size,
                            hparams.num_heads,
                            hparams.attention_dropout,
                            cache=layer_cache)
                        x = common_layers.layer_postprocess(x, y, hparams)
                with tf.variable_scope("ffn"):
                    y = transformer.transformer_ffn_layer(
                        common_layers.layer_preprocess(x, hparams), hparams)
                    x = common_layers.layer_postprocess(x, y, hparams)
        # if normalization is done in layer_preprocess, then it should also be done
        # on the output, since the output can grow very large, being the sum of
        # a whole stack of unnormalized layer outputs.
        return common_layers.layer_preprocess(x, hparams)
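# Sketch (an assumption, not taken from the snippet above) of how the
# decoder_position_forward_mask could be built: the cumulative sum over
# positions is turned into a cumulative average, so the mask holds 1/t at
# position t and has shape [1, t, 1] as stated in the docstring.
length = tf.shape(decoder_input)[1]
positions = tf.cast(tf.range(1, length + 1), tf.float32)
# shape [1, t, 1]
decoder_position_forward_mask = tf.reshape(1.0 / positions, [1, -1, 1])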
Example #25
def transformer_decoder(decoder_input,
                        encoder_output,
                        decoder_self_attention_bias,
                        encoder_decoder_attention_bias,
                        hparams,
                        cache=None,
                        name="decoder",
                        return_attention_weight=False):
  """A stack of transformer layers.

  Args:
    decoder_input: a Tensor
    encoder_output: a Tensor
    decoder_self_attention_bias: bias Tensor for self-attention
      (see common_attention.attention_bias())
    encoder_decoder_attention_bias: bias Tensor for encoder-decoder attention
      (see common_attention.attention_bias())
    hparams: hyperparameters for model
    cache: dict, containing tensors which are the results of previous
        attentions, used for fast decoding.
    name: a string

  Returns:
    y: a Tensor
  """
  x = decoder_input
  additional_outputs = {}
  if return_attention_weight:
    additional_outputs["attention_weight"] = []
  with tf.variable_scope(name):
    for layer in xrange(hparams.num_decoder_layers or
                        hparams.num_hidden_layers):
      layer_name = "layer_%d" % layer
      layer_cache = cache[layer_name] if cache is not None else None
      with tf.variable_scope(layer_name):
        with tf.variable_scope("self_attention"):
          y = common_attention.multihead_attention(
              common_layers.layer_preprocess(x, hparams),
              None,
              decoder_self_attention_bias,
              hparams.attention_key_channels or hparams.hidden_size,
              hparams.attention_value_channels or hparams.hidden_size,
              hparams.hidden_size,
              hparams.num_heads,
              hparams.attention_dropout,
              attention_type=hparams.self_attention_type,
              max_relative_position=hparams.max_relative_position,
              cache=layer_cache)
          x = common_layers.layer_postprocess(x, y, hparams)
        if encoder_output is not None:
          with tf.variable_scope("encdec_attention"):
            # TODO(llion): Add caching.
            y = common_attention.multihead_attention(
                common_layers.layer_preprocess(
                    x, hparams), encoder_output, encoder_decoder_attention_bias,
                hparams.attention_key_channels or hparams.hidden_size,
                hparams.attention_value_channels or hparams.hidden_size,
                hparams.hidden_size, hparams.num_heads,
                hparams.attention_dropout,
                return_attention_weight=return_attention_weight)
            if return_attention_weight:
              y, attention_weight = y
              additional_outputs["attention_weight"].append(attention_weight)
            x = common_layers.layer_postprocess(x, y, hparams)
        with tf.variable_scope("ffn"):
          y = transformer_ffn_layer(
              common_layers.layer_preprocess(x, hparams), hparams)
          x = common_layers.layer_postprocess(x, y, hparams)
    # if normalization is done in layer_preprocess, then it should also be done
    # on the output, since the output can grow very large, being the sum of
    # a whole stack of unnormalized layer outputs.
    if additional_outputs:
      return common_layers.layer_preprocess(x, hparams), additional_outputs
    return common_layers.layer_preprocess(x, hparams)
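# Hedged usage note for the variant above: with return_attention_weight=True
# the function returns a (decoder_output, additional_outputs) pair, where
# additional_outputs["attention_weight"] is a list with one encoder-decoder
# attention tensor per layer; with the default False only the decoder output
# is returned. Variable names below are illustrative.
decoder_output, extras = transformer_decoder(
    decoder_input, encoder_output, decoder_self_attention_bias,
    encoder_decoder_attention_bias, hparams, return_attention_weight=True)
encdec_attention_weights = extras["attention_weight"]  # one entry per layer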
Example #26
        def symbols_to_logits_fn(ids, ids_tag, i, cache):
            """Go from ids to logits for next symbol."""
            ids = ids[:, -1:]
            targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
            targets = preprocess_targets_method(targets, i)

            ids_tag = ids_tag[:, -1:]
            targets_tag = tf.expand_dims(tf.expand_dims(ids_tag, axis=2),
                                         axis=3)
            targets_tag = preprocess_targets_tag_method(targets_tag, i)

            bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]

            with tf.variable_scope('body'):
                with tf.variable_scope('edit_ops_layer'):
                    with tf.variable_scope('ffn'):
                        x = targets
                        preproc = lambda z: common_layers.layer_preprocess(
                            z, hparams, layer_collection=None)
                        layer_inputs = [
                            tf.concat(preproc(x), axis=0),
                            tf.concat(preproc(targets_tag), axis=0),
                        ]
                        y = transformer_layers.transformer_ffn_layer(
                            tf.concat(layer_inputs, axis=2),
                            hparams,
                            conv_padding='LEFT',
                            nonpadding_mask=features_to_nonpadding(
                                features, 'targets'),
                            losses=None,
                            cache=cache,
                            decode_loop_step=None,
                            layer_collection=None,
                        )
                        targets = common_layers.layer_postprocess(
                            x, y, hparams)

                if hparams.middle_prediction:
                    num_decoder_layers = (hparams.num_decoder_layers
                                          or hparams.num_hidden_layers)
                    hparams.num_decoder_layers = int(
                        num_decoder_layers /
                        hparams.middle_prediction_layer_factor)

                body_outputs = dp(
                    self.decode,
                    targets,
                    cache.get('encoder_output'),
                    cache.get('encoder_decoder_attention_bias'),
                    bias,
                    hparams,
                    cache,
                    nonpadding=features_to_nonpadding(features, 'targets'),
                )[0]

                body_outputs, logits_tag = dp(
                    self._prediction_cascade_predict,
                    hparams,
                    features_to_nonpadding(features, 'targets'),
                    cache.get('encoder_decoder_attention_bias'),
                    cache.get('encoder_output'),
                    body_outputs,
                )
                logits_tag = logits_tag[0]['targets_error_tag']
                if hparams.middle_prediction:
                    with tf.variable_scope('after_prediction'):
                        body_outputs = dp(
                            self.decode,
                            targets + body_outputs[0],
                            cache.get('encoder_output'),
                            cache.get('encoder_decoder_attention_bias'),
                            bias,
                            hparams,
                            cache,
                            nonpadding=features_to_nonpadding(
                                features, 'targets'),
                        )

            update_decoder_attention_history(cache)

            modality_name = hparams.name.get(
                'targets',
                modalities.get_name(target_modality))(hparams,
                                                      target_vocab_size)
            with tf.variable_scope('targets/' + modality_name):
                top = hparams.top.get('targets',
                                      modalities.get_top(target_modality))
                logits = dp(top, body_outputs, None, hparams,
                            target_vocab_size)[0]

            ret = tf.squeeze(logits, axis=[1, 2])
            if partial_targets is not None:
                vocab_size = tf.shape(ret)[1]

                def forced_logits():
                    return tf.one_hot(
                        tf.tile(partial_targets[:, i], [beam_size]),
                        vocab_size,
                        0.0,
                        -1e9,
                    )

                ret = tf.cond(
                    tf.less(i, partial_targets_length),
                    forced_logits,
                    lambda: ret,
                )
            logits_tag = tf.squeeze(logits_tag, axis=[1])
            return ret, logits_tag, cache
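# Note on forced_logits in symbols_to_logits_fn above:
# tf.one_hot(token_id, vocab_size, 0.0, -1e9) assigns logit 0.0 to the forced
# token and -1e9 to every other entry, so while i < partial_targets_length the
# beam search can only continue the supplied prefix; afterwards the model's
# own logits (`ret`) are used unchanged.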
Example #27
def transformer_decoder(decoder_input,
                        encoder_output,
                        decoder_self_attention_bias,
                        encoder_decoder_attention_bias,
                        hparams,
                        cache=None,
                        name="decoder",
                        nonpadding=None,
                        save_weights_to=None):
    """A stack of transformer layers.

  Args:
    decoder_input: a Tensor
    encoder_output: a Tensor
    decoder_self_attention_bias: bias Tensor for self-attention
      (see common_attention.attention_bias())
    encoder_decoder_attention_bias: bias Tensor for encoder-decoder attention
      (see common_attention.attention_bias())
    hparams: hyperparameters for model
    cache: dict, containing tensors which are the results of previous
        attentions, used for fast decoding.
    name: a string
    nonpadding: optional Tensor with shape [batch_size, encoder_length]
      indicating what positions are not padding.  This is used
      to mask out padding in convolutional layers.  We generally only
      need this mask for "packed" datasets, because for ordinary datasets,
      no padding is ever followed by nonpadding.
    save_weights_to: an optional dictionary to capture attention weights
      for visualization; the weights tensor will be appended there under
      a string key created from the variable scope (including name).

  Returns:
    y: a Tensor
  """
    x = decoder_input
    with tf.variable_scope(name):
        for layer in xrange(hparams.num_decoder_layers
                            or hparams.num_hidden_layers):
            layer_name = "layer_%d" % layer
            layer_cache = cache[layer_name] if cache is not None else None
            with tf.variable_scope(layer_name):
                with tf.variable_scope("self_attention"):
                    y = common_attention.multihead_attention(
                        common_layers.layer_preprocess(x, hparams),
                        None,
                        decoder_self_attention_bias,
                        hparams.attention_key_channels or hparams.hidden_size,
                        hparams.attention_value_channels
                        or hparams.hidden_size,
                        hparams.hidden_size,
                        hparams.num_heads,
                        hparams.attention_dropout,
                        attention_type=hparams.self_attention_type,
                        save_weights_to=save_weights_to,
                        max_relative_position=hparams.max_relative_position,
                        cache=layer_cache)
                    x = common_layers.layer_postprocess(x, y, hparams)
                if encoder_output is not None:
                    with tf.variable_scope("encdec_attention"):
                        # TODO(llion): Add caching.
                        y = common_attention.multihead_attention(
                            common_layers.layer_preprocess(x, hparams),
                            encoder_output,
                            encoder_decoder_attention_bias,
                            hparams.attention_key_channels
                            or hparams.hidden_size,
                            hparams.attention_value_channels
                            or hparams.hidden_size,
                            hparams.hidden_size,
                            hparams.num_heads,
                            hparams.attention_dropout,
                            save_weights_to=save_weights_to)
                        x = common_layers.layer_postprocess(x, y, hparams)
                with tf.variable_scope("ffn"):
                    y = transformer_ffn_layer(common_layers.layer_preprocess(
                        x, hparams),
                                              hparams,
                                              conv_padding="LEFT",
                                              nonpadding_mask=nonpadding)
                    x = common_layers.layer_postprocess(x, y, hparams)
        # if normalization is done in layer_preprocess, then it should also be done
        # on the output, since the output can grow very large, being the sum of
        # a whole stack of unnormalized layer outputs.
        return common_layers.layer_preprocess(x, hparams)
Example #28
def transformer_encoder(encoder_input,
                        encoder_self_attention_bias,
                        hparams,
                        name="encoder",
                        nonpadding=None,
                        save_weights_to=None,
                        make_image_summary=True,
                        losses=None,
                        attn_bias_for_padding=None):
    """A stack of transformer layers.

  Args:
    encoder_input: a Tensor
    encoder_self_attention_bias: bias Tensor for self-attention
       (see common_attention.attention_bias())
    hparams: hyperparameters for model
    name: a string
    nonpadding: optional Tensor with shape [batch_size, encoder_length]
      indicating what positions are not padding.  This must either be
      passed in, which we do for "packed" datasets, or inferred from
      encoder_self_attention_bias.  The knowledge about padding is used
      for pad_remover(efficiency) and to mask out padding in convolutional
      layers.
    save_weights_to: an optional dictionary to capture attention weights
      for visualization; the weights tensor will be appended there under
      a string key created from the variable scope (including name).
    make_image_summary: Whether to make an attention image summary.
    losses: optional list onto which to append extra training losses
    attn_bias_for_padding: Padded attention bias in case a unidirectional
      encoder is being used where future attention is masked.

  Returns:
    y: a Tensor
  """
    x = encoder_input
    attention_dropout_broadcast_dims = (
        common_layers.comma_separated_string_to_integer_list(
            getattr(hparams, "attention_dropout_broadcast_dims", "")))
    mlperf_log.transformer_print(key=mlperf_log.MODEL_HP_NUM_HIDDEN_LAYERS,
                                 value=hparams.num_encoder_layers
                                 or hparams.num_hidden_layers)
    mlperf_log.transformer_print(key=mlperf_log.MODEL_HP_ATTENTION_DROPOUT,
                                 value=hparams.attention_dropout)
    mlperf_log.transformer_print(key=mlperf_log.MODEL_HP_ATTENTION_DENSE,
                                 value={
                                     "use_bias": "false",
                                     "num_heads": hparams.num_heads,
                                     "hidden_size": hparams.hidden_size
                                 })

    with tf.variable_scope(name):
        if nonpadding is not None:
            padding = 1.0 - nonpadding
        else:
            attention_bias = encoder_self_attention_bias
            if attn_bias_for_padding is not None:
                attention_bias = attn_bias_for_padding
            padding = common_attention.attention_bias_to_padding(
                attention_bias)
            nonpadding = 1.0 - padding
        pad_remover = None
        if hparams.use_pad_remover and not common_layers.is_xla_compiled():
            pad_remover = expert_utils.PadRemover(padding)
        for layer in range(hparams.num_encoder_layers
                           or hparams.num_hidden_layers):
            with tf.variable_scope("layer_%d" % layer):
                with tf.variable_scope("self_attention"):
                    y = common_attention.multihead_attention(
                        common_layers.layer_preprocess(x, hparams),
                        None,
                        encoder_self_attention_bias,
                        hparams.attention_key_channels or hparams.hidden_size,
                        hparams.attention_value_channels
                        or hparams.hidden_size,
                        hparams.hidden_size,
                        hparams.num_heads,
                        hparams.attention_dropout,
                        attention_type=hparams.self_attention_type,
                        max_relative_position=hparams.max_relative_position,
                        heads_share_relative_embedding=(
                            hparams.heads_share_relative_embedding),
                        add_relative_to_values=hparams.add_relative_to_values,
                        save_weights_to=save_weights_to,
                        make_image_summary=make_image_summary,
                        dropout_broadcast_dims=attention_dropout_broadcast_dims,
                        max_length=hparams.get("max_length"),
                        vars_3d=hparams.get("attention_variables_3d"),
                        activation_dtype=hparams.get("activation_dtype",
                                                     "float32"),
                        weight_dtype=hparams.get("weight_dtype", "float32"))
                    x = common_layers.layer_postprocess(x, y, hparams)
                with tf.variable_scope("ffn"):
                    y = transformer_ffn_layer(common_layers.layer_preprocess(
                        x, hparams),
                                              hparams,
                                              pad_remover,
                                              conv_padding="SAME",
                                              nonpadding_mask=nonpadding,
                                              losses=losses)
                    x = common_layers.layer_postprocess(x, y, hparams)
        # if normalization is done in layer_preprocess, then it should also be done
        # on the output, since the output can grow very large, being the sum of
        # a whole stack of unnormalized layer outputs.
        mlperf_log.transformer_print(
            key=mlperf_log.MODEL_HP_NORM,
            value={"hidden_size": hparams.hidden_size})
        return common_layers.layer_preprocess(x, hparams)
def transformer_decoder(decoder_input,
                        encoder_output,
                        decoder_self_attention_biases,
                        encoder_decoder_attention_biases,
                        hparams,
                        cache=None,
                        decode_loop_step=None,
                        name="decoder",
                        nonpadding=None,
                        save_weights_to=None,
                        make_image_summary=True,
                        losses=None):
  """A stack of transformer layers.

  Args:
    decoder_input: a Tensor
    encoder_output: a Tensor
    decoder_self_attention_biases: dict mapping each entry of
      hparams.transformer_context_types to a self-attention bias Tensor
      (see common_attention.attention_bias())
    encoder_decoder_attention_biases: bias Tensor for encoder-decoder attention
      (see common_attention.attention_bias())
    hparams: hyperparameters for model
    cache: dict, containing tensors which are the results of previous
        attentions, used for fast decoding.
    decode_loop_step: An integer, step number of the decoding loop.
        Only used for inference on TPU.
    name: a string
    nonpadding: optional Tensor with shape [batch_size, encoder_length]
      indicating what positions are not padding.  This is used
      to mask out padding in convolutional layers.  We generally only
      need this mask for "packed" datasets, because for ordinary datasets,
      no padding is ever followed by nonpadding.
    save_weights_to: an optional dictionary to capture attention weights
      for visualization; the weights tensor will be appended there under
      a string key created from the variable scope (including name).
    make_image_summary: Whether to make an attention image summary.
    losses: optional list onto which to append extra training losses

  Returns:
    y: a Tensor
  """
  x = decoder_input
  attention_dropout_broadcast_dims = (
      common_layers.comma_separated_string_to_integer_list(
          getattr(hparams, "attention_dropout_broadcast_dims", "")))
  with tf.variable_scope(name):
    for layer in range(hparams.num_decoder_layers or hparams.num_hidden_layers):
      layer_name = "layer_%d" % layer
      layer_cache = cache[layer_name] if cache is not None else None
      with tf.variable_scope(layer_name):
        for context_type in hparams.transformer_context_types:
          with tf.variable_scope("self_attention_%s" % context_type):
            y = common_attention.multihead_attention(
                common_layers.layer_preprocess(x, hparams),
                None,
                decoder_self_attention_biases[context_type],
                hparams.attention_key_channels or hparams.hidden_size,
                hparams.attention_value_channels or hparams.hidden_size,
                hparams.hidden_size,
                hparams.num_heads,
                hparams.attention_dropout,
                attention_type=hparams.self_attention_type,
                max_relative_position=hparams.max_relative_position,
                heads_share_relative_embedding=(
                    hparams.heads_share_relative_embedding),
                add_relative_to_values=hparams.add_relative_to_values,
                save_weights_to=save_weights_to,
                cache=layer_cache,
                make_image_summary=make_image_summary,
                dropout_broadcast_dims=attention_dropout_broadcast_dims,
                max_length=hparams.get("max_length"),
                decode_loop_step=decode_loop_step,
                vars_3d=hparams.get("attention_variables_3d"))
            x = common_layers.layer_postprocess(x, y, hparams)
        if encoder_output is not None:
          with tf.variable_scope("encdec_attention"):
            y = common_attention.multihead_attention(
                common_layers.layer_preprocess(x, hparams),
                encoder_output,
                encoder_decoder_attention_biases,
                hparams.attention_key_channels or hparams.hidden_size,
                hparams.attention_value_channels or hparams.hidden_size,
                hparams.hidden_size,
                hparams.num_heads,
                hparams.attention_dropout,
                max_relative_position=hparams.max_relative_position,
                heads_share_relative_embedding=(
                    hparams.heads_share_relative_embedding),
                add_relative_to_values=hparams.add_relative_to_values,
                save_weights_to=save_weights_to,
                cache=layer_cache,
                make_image_summary=make_image_summary,
                dropout_broadcast_dims=attention_dropout_broadcast_dims,
                max_length=hparams.get("max_length"),
                vars_3d=hparams.get("attention_variables_3d"))
            x = common_layers.layer_postprocess(x, y, hparams)
        with tf.variable_scope("ffn"):
          y = transformer_ffn_layer(
              common_layers.layer_preprocess(x, hparams),
              hparams,
              conv_padding="LEFT",
              nonpadding_mask=nonpadding,
              losses=losses,
              cache=layer_cache,
              decode_loop_step=decode_loop_step)
          x = common_layers.layer_postprocess(x, y, hparams)
    # if normalization is done in layer_preprocess, then it should also be done
    # on the output, since the output can grow very large, being the sum of
    # a whole stack of unnormalized layer outputs.
    return common_layers.layer_preprocess(x, hparams)
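# Sketch (an assumption about the calling convention, not from the snippet
# above): decoder_self_attention_biases is a dict keyed by the entries of
# hparams.transformer_context_types, each value a self-attention bias Tensor;
# the keys and variable names below are purely illustrative.
decoder_self_attention_biases = {
    "causal": common_attention.attention_bias_lower_triangle(decode_length),
    "context": context_attention_bias,  # built by the caller, not shown here
}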
def evolved_transformer_encoder(encoder_input,
                                encoder_self_attention_bias,
                                hparams,
                                name="encoder",
                                nonpadding=None,
                                save_weights_to=None,
                                make_image_summary=True,
                                losses=None,
                                attn_bias_for_padding=None):
    """Evolved Transformer encoder. See arxiv.org/abs/1901.11117 for more details.

  Note: Pad remover is not supported.

  Args:
    encoder_input: a Tensor.
    encoder_self_attention_bias: bias Tensor for self-attention (see
      common_attention.attention_bias()).
    hparams: hyperparameters for model.
    name: a string.
    nonpadding: optional Tensor with shape [batch_size, encoder_length]
      indicating what positions are not padding.  This must either be passed in,
      which we do for "packed" datasets, or inferred from
      encoder_self_attention_bias.  The knowledge about padding is used for
      pad_remover(efficiency) and to mask out padding in convolutional layers.
    save_weights_to: an optional dictionary to capture attention weights for
      visualization; the weights tensor will be appended there under a string
      key created from the variable scope (including name).
    make_image_summary: Whether to make an attention image summary.
    losses: Not used.
    attn_bias_for_padding: Padded attention bias in case a unidirectional
      encoder is being used where future attention is masked.

  Returns:
    Tensor encoder output.
  """
    del losses

    hidden_state = encoder_input
    attention_dropout_broadcast_dims = (
        common_layers.comma_separated_string_to_integer_list(
            getattr(hparams, "attention_dropout_broadcast_dims", "")))

    with tf.variable_scope(name):
        if nonpadding is not None:
            padding = 1.0 - nonpadding
        else:
            attention_bias = encoder_self_attention_bias
            if attn_bias_for_padding is not None:
                attention_bias = attn_bias_for_padding
            padding = common_attention.attention_bias_to_padding(
                attention_bias)
            nonpadding = 1.0 - padding

        for layer in range(hparams.num_encoder_layers
                           or hparams.num_hidden_layers):
            with tf.variable_scope("layer_%d" % layer):

                with tf.variable_scope("gated_linear_unit"):

                    residual_state = hidden_state
                    hidden_state = common_layers.layer_preprocess(
                        hidden_state, hparams)

                    values = common_layers.layers().Dense(
                        hparams.hidden_size)(hidden_state)
                    gates = common_layers.layers().Dense(
                        hparams.hidden_size,
                        activation=tf.nn.sigmoid)(hidden_state)
                    hidden_state = values * gates

                    hidden_state = common_layers.layer_postprocess(
                        residual_state, hidden_state, hparams)

                with tf.variable_scope("conv_branches"):

                    residual_state = hidden_state
                    hidden_state = common_layers.layer_preprocess(
                        hidden_state, hparams)
                    # Mask padding from conv layers.
                    mask = tf.tile(tf.expand_dims(nonpadding, 2),
                                   [1, 1, hparams.hidden_size])
                    hidden_state *= mask

                    left_output_dim = int(hparams.hidden_size * 4)
                    left_state = common_layers.layers().Dense(
                        left_output_dim, activation=tf.nn.relu)(hidden_state)
                    left_state = tf.nn.dropout(
                        left_state, 1 - hparams.layer_prepostprocess_dropout)

                    right_output_dim = int(hparams.hidden_size / 2)
                    right_state = common_layers.layers().Conv1D(
                        right_output_dim,
                        3,
                        padding="SAME",
                        name="standard_conv_3x1",
                        activation=tf.nn.relu)(hidden_state)
                    right_state = tf.nn.dropout(
                        right_state, 1 - hparams.layer_prepostprocess_dropout)

                    right_state = tf.pad(
                        right_state, [[0, 0], [0, 0],
                                      [0, left_output_dim - right_output_dim]],
                        constant_values=0)
                    hidden_state = left_state + right_state

                    hidden_state = common_layers.layer_preprocess(
                        hidden_state, hparams)
                    # Mask padding from conv layer.
                    mask = tf.tile(tf.expand_dims(nonpadding, 2),
                                   [1, 1, left_output_dim])
                    hidden_state *= mask

                    separable_conv_9x1 = common_layers.layers(
                    ).SeparableConv1D(right_output_dim,
                                      9,
                                      padding="SAME",
                                      name="separable_conv_9x1")
                    hidden_state = separable_conv_9x1(hidden_state)
                    hidden_state = tf.pad(
                        hidden_state,
                        [[0, 0], [0, 0],
                         [0, hparams.hidden_size - right_output_dim]],
                        constant_values=0)

                    hidden_state = common_layers.layer_postprocess(
                        residual_state, hidden_state, hparams)

                with tf.variable_scope("self_attention"):
                    residual_state = hidden_state
                    hidden_state = common_layers.layer_preprocess(
                        hidden_state, hparams)

                    hidden_state = common_attention.multihead_attention(
                        hidden_state,
                        None,
                        encoder_self_attention_bias,
                        hparams.attention_key_channels or hparams.hidden_size,
                        hparams.attention_value_channels
                        or hparams.hidden_size,
                        hparams.hidden_size,
                        hparams.num_heads,
                        hparams.attention_dropout,
                        attention_type=hparams.self_attention_type,
                        max_relative_position=hparams.max_relative_position,
                        heads_share_relative_embedding=(
                            hparams.heads_share_relative_embedding),
                        add_relative_to_values=hparams.add_relative_to_values,
                        save_weights_to=save_weights_to,
                        make_image_summary=make_image_summary,
                        dropout_broadcast_dims=attention_dropout_broadcast_dims,
                        max_length=hparams.get("max_length"),
                        vars_3d=hparams.get("attention_variables_3d"),
                        activation_dtype=hparams.get("activation_dtype",
                                                     "float32"),
                        weight_dtype=hparams.get("weight_dtype", "float32"))

                    hidden_state = common_layers.layer_postprocess(
                        residual_state, hidden_state, hparams)

                with tf.variable_scope("dense_layers"):
                    residual_state = hidden_state
                    hidden_state = common_layers.layer_preprocess(
                        hidden_state, hparams)

                    hidden_state = common_layers.layers().Dense(
                        int(hparams.hidden_size * 4),
                        activation=tf.nn.relu)(hidden_state)
                    hidden_state = tf.nn.dropout(
                        hidden_state, 1 - hparams.layer_prepostprocess_dropout)

                    hidden_state = common_layers.layers().Dense(
                        hparams.hidden_size)(hidden_state)
                    hidden_state = common_layers.layer_postprocess(
                        residual_state, hidden_state, hparams)

        # If normalization is done in layer_preprocess, then it should also be done
        # on the output, since the output can grow very large, being the sum of
        # a whole stack of unnormalized layer outputs.
        return common_layers.layer_preprocess(hidden_state, hparams)
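# Shape bookkeeping for the conv branches above: the left branch projects to
# 4*hidden_size (Dense + ReLU) while the right branch is a 3x1 conv to
# hidden_size/2, zero-padded up to 4*hidden_size before the two are summed;
# the 9x1 separable conv then maps back to hidden_size/2 and is zero-padded to
# hidden_size so the residual connection in layer_postprocess lines up.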
Example #31
def transformer_encoder(encoder_input,
                        raw_inputs,
                        encoder_self_attention_bias,
                        hparams,
                        name="encoder"):
    """A stack of transformer layers.

  Args:
    encoder_input: a Tensor
    raw_inputs: the raw (pre-embedding) encoder inputs, used to derive the
      sequence length and positional signals
    encoder_self_attention_bias: bias Tensor for self-attention
       (see common_attention.attention_bias())
    hparams: hyperparameters for model
    name: a string

  Returns:
    y: a Tensor
  """
    x = encoder_input
    with tf.variable_scope(name):
        raw_encoder_input = tf.squeeze(raw_inputs, axis=[-2, -1])
        sequence_length = usr_utils.get_length_from_raw(
            raw_encoder_input)  # Used for RNNs
        pos_signals = generate_positional_signals(raw_encoder_input, hparams)
        pos_embeddings = generate_positional_embeddings(
            pos_signals, hparams.encoder_pos, hparams)
        attention_pos_embeddings = generate_positional_embeddings(
            pos_signals, hparams.encoder_attention_pos, hparams)
        if "sum" in hparams.pos_integration:
            x = x + pos_embeddings
        elif "ffn" in hparams.pos_integration:
            with tf.variable_scope("pos_ffn"):
                x = tf.concat([x, pos_embeddings], axis=2)
                x = transformer_ffn_layer(x, hparams)
        pad_remover = None
        if hparams.use_pad_remover:
            pad_remover = expert_utils.PadRemover(
                common_attention.attention_bias_to_padding(
                    encoder_self_attention_bias))
        for layer in xrange(hparams.num_encoder_layers
                            or hparams.num_hidden_layers):
            with tf.variable_scope("layer_%d" % layer):
                for layer_type in _iter_layer_types(
                        hparams.encoder_layer_types, layer):
                    if layer_type == "self_att":
                        with tf.variable_scope("self_attention"):
                            y = model_helper.multihead_attention_qkv(
                                common_layers.layer_preprocess(x, hparams),
                                None,
                                None,
                                encoder_self_attention_bias,
                                hparams.attention_key_channels
                                or hparams.hidden_size,
                                hparams.attention_value_channels
                                or hparams.hidden_size,
                                hparams.hidden_size,
                                hparams.num_heads,
                                hparams.attention_dropout,
                                attention_type=hparams.
                                encoder_self_attention_type,
                                attention_order=hparams.attention_order,
                                max_relative_position=hparams.
                                max_relative_position)
                            x = common_layers.layer_postprocess(x, y, hparams)
                    elif layer_type == "rnn":
                        with tf.variable_scope("recurrent"):
                            y = transformer_rnn_layer(
                                common_layers.layer_preprocess(x, hparams),
                                sequence_length, hparams)
                            x = common_layers.layer_postprocess(x, y, hparams)
                    elif layer_type == "birnn":
                        with tf.variable_scope("recurrent"):
                            y = transformer_rnn_layer(
                                common_layers.layer_preprocess(x, hparams),
                                sequence_length,
                                hparams,
                                bidirectional=True)
                            x = common_layers.layer_postprocess(x, y, hparams)
                    elif layer_type == "pos_self_att" and attention_pos_embeddings is not None:
                        with tf.variable_scope("pos_self_attention"):
                            y = model_helper.multihead_attention_qkv(
                                attention_pos_embeddings,  # Query
                                attention_pos_embeddings,  # Key
                                common_layers.layer_preprocess(
                                    x, hparams),  # Value
                                encoder_self_attention_bias,
                                hparams.attention_key_channels
                                or hparams.hidden_size,
                                hparams.attention_value_channels
                                or hparams.hidden_size,
                                hparams.hidden_size,
                                hparams.num_heads,
                                hparams.attention_dropout,
                                attention_type=hparams.pos_self_attention_type,
                                attention_order=hparams.attention_order,
                                max_relative_position=hparams.
                                max_relative_position)
                            x = common_layers.layer_postprocess(x, y, hparams)
                    else:
                        tf.logging.warn(
                            "Ignoring '%s' in encoder_layer_types" %
                            layer_type)
                with tf.variable_scope("ffn"):
                    y = transformer_ffn_layer(
                        common_layers.layer_preprocess(x, hparams), hparams,
                        pad_remover)
                    x = common_layers.layer_postprocess(x, y, hparams)
        # if normalization is done in layer_preprocess, then it should also be done
        # on the output, since the output can grow very large, being the sum of
        # a whole stack of unnormalized layer outputs.
        return common_layers.layer_preprocess(x, hparams)
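This encoder assembles each layer from the sublayer types listed in hparams.encoder_layer_types via an _iter_layer_types helper that is not shown in the snippet. As a rough, hypothetical sketch (the helper name, the ';'-per-layer / ','-per-sublayer spec format and the cyclic reuse are all assumptions, not the original implementation), such a helper could look like:

def _iter_layer_types_sketch(layer_types_spec, layer):
    """Hypothetical: yield sublayer types for `layer` from a spec string.

    The spec is assumed to hold one semicolon-separated group per layer,
    each group a comma-separated list of sublayer types, e.g.
    "self_att;self_att,rnn" -> layer 0: self_att; layer 1: self_att, rnn.
    Groups are reused cyclically if there are fewer groups than layers.
    """
    groups = [g.split(",") for g in layer_types_spec.split(";") if g]
    if not groups:
        return
    for layer_type in groups[layer % len(groups)]:
        yield layer_type.strip()


# Example: a 3-layer encoder alternating attention-only and attention+RNN layers.
for layer in range(3):
    print(layer, list(_iter_layer_types_sketch("self_att;self_att,rnn", layer)))

Any type name that the encoder loop does not recognize then falls through to the tf.logging.warn branch above.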
Example #32
0
def transformer_layers_sharded(dp,
                               ps_devices,
                               inputs,
                               num_layers,
                               hparams,
                               self_attention_bias=None,
                               enc_output=None,
                               attention_type=AttentionType.GLOBAL,
                               name="transformer"):
    """Multi layer transformer, sharded by the data parallelism dp."""
    x = inputs
    extra_loss = tf.constant(0.0)
    moe_hidden_sizes = [int(s) for s in hparams.moe_hidden_sizes.split(",")]
    expert_fn = expert_utils.ffn_expert_fn(hparams.hidden_size,
                                           moe_hidden_sizes,
                                           hparams.hidden_size)
    x = dp(tf.nn.dropout, x, 1.0 - hparams.layer_prepostprocess_dropout)
    for layer in xrange(num_layers):
        with tf.variable_scope("%s_layer_%d" % (name, layer)):
            # self-attention
            if attention_type == AttentionType.LOCAL_2D:
                y = dp(local_attention_2d,
                       common_layers.layer_preprocess(x, hparams),
                       hparams,
                       attention_type="masked_local_attention_2d")
            elif attention_type == AttentionType.LOCAL_1D:
                y = dp(local_attention_1d,
                       common_layers.layer_preprocess(x, hparams),
                       hparams,
                       attention_type="local_mask_right",
                       q_padding="LEFT",
                       kv_padding="LEFT")
            elif attention_type == AttentionType.GLOCAL:
                y = dp(local_global_attention,
                       common_layers.layer_preprocess(x, hparams),
                       self_attention_bias,
                       hparams,
                       q_padding="LEFT",
                       kv_padding="LEFT")
            elif attention_type == AttentionType.GLOBAL:
                self_attention_bias = dp(get_self_attention_bias, x)
                y = dp(full_self_attention,
                       common_layers.layer_preprocess(x, hparams),
                       self_attention_bias,
                       hparams,
                       q_padding="LEFT",
                       kv_padding="LEFT")
            x = dp(common_layers.layer_postprocess, x, y, hparams)
            if enc_output is not None:
                y = dp(encdec_attention_1d,
                       common_layers.layer_preprocess(x, hparams), enc_output,
                       None, hparams)
                x = dp(common_layers.layer_postprocess, x, y, hparams)
            with tf.variable_scope("ffn"):
                if str(layer) in hparams.moe_layers_decoder.split(","):
                    y, loss = expert_utils.distributed_moe(
                        dp,
                        ps_devices,
                        common_layers.layer_preprocess(x, hparams),
                        hparams.mode == tf.estimator.ModeKeys.TRAIN,
                        input_size=hparams.hidden_size,
                        expert_fn=expert_fn,
                        num_experts=hparams.moe_num_experts,
                        k=hparams.moe_k,
                        loss_coef=hparams.moe_loss_coef)
                    extra_loss += loss
                    x = dp(common_layers.layer_postprocess, x, y, hparams)
                else:
                    y = dp(ffn_layer,
                           common_layers.layer_preprocess(x, hparams), hparams)
                    x = dp(common_layers.layer_postprocess, x, y, hparams)
    return dp(common_layers.layer_preprocess, x, hparams), extra_loss
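In the sharded variant, dp is a data-parallelism wrapper (expert_utils.Parallelism in tensor2tensor) whose calling convention is dp(fn, *args): fn is invoked once per shard, list-valued arguments are distributed one element per shard, and a list of per-shard results comes back. This is why the calls above pass the function and its arguments separately rather than an already-applied fn(...). A minimal pure-Python stand-in that mimics only this calling convention (not the real class):

class ParallelismSketch(object):
    """Minimal stand-in for a dp(fn, *args) style data-parallelism wrapper."""

    def __init__(self, n):
        self.n = n  # number of shards / devices

    def __call__(self, fn, *args):
        outputs = []
        for i in range(self.n):
            # List-valued args are assumed to be pre-sharded, one entry per
            # shard; everything else is broadcast to every shard.
            shard_args = [a[i] if isinstance(a, list) else a for a in args]
            outputs.append(fn(*shard_args))
        return outputs


dp = ParallelismSketch(2)
x = [[1.0, 2.0], [3.0, 4.0]]                       # one sub-list per shard
print(dp(lambda xs, scale: [v * scale for v in xs], x, 10.0))
# [[10.0, 20.0], [30.0, 40.0]]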
  def body(self, features):
    assert self._hparams.block_size > 0
    assert not common_layers.is_xla_compiled()

    hparams = copy.copy(self._hparams)
    targets = features["targets"]
    inputs = features["inputs"]
    if not (tf.get_variable_scope().reuse or
            hparams.mode == tf.estimator.ModeKeys.PREDICT):
      tf.summary.image("inputs", inputs, max_outputs=1)
      tf.summary.image("targets", targets, max_outputs=1)

    encoder_input = cia.prepare_encoder(inputs, hparams)
    encoder_output = cia.transformer_encoder_layers(
        encoder_input,
        hparams.num_encoder_layers,
        hparams,
        attention_type=hparams.enc_attention_type,
        name="encoder")
    decoder_input, rows, cols = cia.prepare_decoder(
        targets, hparams)
    decoder_output = cia.transformer_decoder_layers(
        decoder_input,
        encoder_output,
        hparams.num_decoder_layers,
        hparams,
        attention_type=hparams.dec_attention_type,
        name="decoder")

    assert not isinstance(decoder_output, tuple)
    assert len(decoder_output.shape) == 4

    relu_dropout_broadcast_dims = (
        common_layers.comma_separated_string_to_integer_list(
            getattr(self._hparams, "relu_dropout_broadcast_dims", "")))

    with tf.variable_scope("block_size_%d" % self._hparams.block_size):
      tf.logging.info("Using block_size %d", self._hparams.block_size)
      block_output = common_layers.dense_relu_dense(
          decoder_output,
          self._hparams.block_size * self._hparams.filter_size,
          self._hparams.block_size * self._hparams.hidden_size,
          dropout=self._hparams.relu_dropout,
          dropout_broadcast_dims=relu_dropout_broadcast_dims)

    batch_size, rows, cols = common_layers.shape_list(decoder_output)[:3]
    decoder_output = tf.reshape(decoder_output, [
        batch_size,
        rows,
        cols,
        1,
        self._hparams.hidden_size
    ])
    block_output = tf.reshape(block_output, [
        batch_size,
        rows,
        cols,
        self._hparams.block_size,
        self._hparams.hidden_size
    ])

    block_output = common_layers.layer_postprocess(
        decoder_output, block_output, self._hparams)

    return block_output
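The body method widens every decoder position into block_size output positions: dense_relu_dense maps hidden_size channels to block_size * hidden_size, and the two reshapes expose the block as its own axis so layer_postprocess can broadcast the [..., 1, hidden_size] decoder residual against the [..., block_size, hidden_size] block output. A small NumPy check of that shape bookkeeping (toy sizes, plain broadcasting standing in for layer_postprocess):

import numpy as np

batch, rows, cols = 2, 4, 4
hidden_size, block_size = 8, 3

decoder_output = np.random.randn(batch, rows, cols, hidden_size)
# Stand-in for dense_relu_dense(..., block_size * filter_size, block_size * hidden_size).
block_output = np.random.randn(batch, rows, cols, block_size * hidden_size)

decoder_output = decoder_output.reshape(batch, rows, cols, 1, hidden_size)
block_output = block_output.reshape(batch, rows, cols, block_size, hidden_size)

# The residual connection broadcasts the single decoder position over the block.
combined = decoder_output + block_output
print(combined.shape)  # (2, 4, 4, 3, 8)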
Example #34
0
def decoder(
    decoder_input,
    encoder_output,
    decoder_self_attention_bias,
    encoder_decoder_attention_bias,
    hparams,
    name="decoder",
    save_weights_to=None,
    make_image_summary=True,
):
    """A stack of transformer layers.

  Args:
    decoder_input: a Tensor
    encoder_output: a Tensor
    decoder_self_attention_bias: bias Tensor for self-attention
      (see common_attention.attention_bias())
    encoder_decoder_attention_bias: bias Tensor for encoder-decoder attention
      (see common_attention.attention_bias())
    hparams: hyperparameters for model
    name: a string
    save_weights_to: an optional dictionary to capture attention weights
      for visualization; the weights tensor will be appended there under
      a string key created from the variable scope (including name).
    make_image_summary: Whether to make an attention image summary.

  Returns:
    y: a Tensor
  """
    x = decoder_input
    with tf.variable_scope(name):
        for layer in range(hparams.num_decoder_layers
                           or hparams.num_hidden_layers):
            layer_name = "layer_%d" % layer
            with tf.variable_scope(layer_name):
                with tf.variable_scope("self_attention"):
                    y = common_attention.multihead_attention(
                        common_layers.layer_preprocess(x, hparams),
                        None,
                        decoder_self_attention_bias,
                        hparams.attention_key_channels or hparams.hidden_size,
                        hparams.attention_value_channels
                        or hparams.hidden_size,
                        hparams.hidden_size,
                        hparams.num_heads,
                        hparams.attention_dropout,
                        attention_type=hparams.self_attention_type,
                        save_weights_to=save_weights_to,
                        make_image_summary=make_image_summary,
                    )
                    utils.collect_named_outputs(
                        "norms", "decoder_self_attention_%d" % (layer),
                        tf.norm(y, axis=-1))
                    x = common_layers.layer_postprocess(x, y, hparams)
                    utils.collect_named_outputs(
                        "norms", "decoder_self_attention_post_%d" % (layer),
                        tf.norm(x, axis=-1))
                if encoder_output is not None:
                    with tf.variable_scope("encdec_attention"):
                        y = common_attention.multihead_attention(
                            common_layers.layer_preprocess(x, hparams),
                            encoder_output,
                            encoder_decoder_attention_bias,
                            hparams.attention_key_channels
                            or hparams.hidden_size,
                            hparams.attention_value_channels
                            or hparams.hidden_size,
                            hparams.hidden_size,
                            hparams.num_heads,
                            hparams.attention_dropout,
                            save_weights_to=save_weights_to,
                            make_image_summary=make_image_summary,
                        )
                        utils.collect_named_outputs(
                            "norms", "decoder_encoder_attention_%d" % (layer),
                            tf.norm(y, axis=-1))
                        x = common_layers.layer_postprocess(x, y, hparams)
                        utils.collect_named_outputs(
                            "norms",
                            "decoder_encoder_attention_post_%d" % (layer),
                            tf.norm(x, axis=-1))
                with tf.variable_scope("ffn"):
                    y = common_layers.dense_relu_dense(
                        common_layers.layer_preprocess(x, hparams),
                        hparams.filter_size,
                        hparams.hidden_size,
                        dropout=hparams.relu_dropout,
                    )
                    utils.collect_named_outputs("norms",
                                                "decoder_ffn_%d" % (layer),
                                                tf.norm(y, axis=-1))
                    x = common_layers.layer_postprocess(x, y, hparams)
                    utils.collect_named_outputs(
                        "norms", "decoder_ffn_post_%d" % (layer),
                        tf.norm(x, axis=-1))
        # if normalization is done in layer_preprocess, then it should also be done
        # on the output, since the output can grow very large, being the sum of
        # a whole stack of unnormalized layer outputs.
        return common_layers.layer_preprocess(x, hparams)
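This decoder also records tf.norm(y, axis=-1) for every sublayer through utils.collect_named_outputs, i.e. one L2 norm per (batch, position) both before and after layer_postprocess, which is a cheap way to watch activation scale move through the stack. The NumPy equivalent of the quantity being collected:

import numpy as np

y = np.random.randn(2, 5, 16)                    # [batch, length, hidden_size]
per_position_norm = np.linalg.norm(y, axis=-1)   # what tf.norm(y, axis=-1) records
print(per_position_norm.shape)                   # (2, 5): one scalar per token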
Example #35
0
def transformer_decoder(decoder_input,
                        encoder_output,
                        raw_targets,
                        decoder_self_attention_bias,
                        encoder_decoder_attention_bias,
                        hparams,
                        cache=None,
                        name="decoder"):
    """A stack of transformer layers.

  Args:
    decoder_input: a Tensor
    encoder_output: a Tensor
    raw_targets: a raw target Tensor, used to derive sequence lengths,
      terminal/nonterminal decoder biases and positional signals
    decoder_self_attention_bias: bias Tensor for self-attention
      (see common_attention.attention_bias())
    encoder_decoder_attention_bias: bias Tensor for encoder-decoder attention
      (see common_attention.attention_bias())
    hparams: hyperparameters for model
    cache: dict, containing tensors which are the results of previous
        attentions, used for fast decoding.
    name: a string

  Returns:
    y: a Tensor
  """
    x = decoder_input
    with tf.variable_scope(name):
        sequence_length = usr_utils.get_length_from_raw(
            tf.squeeze(raw_targets, axis=[-2, -1]))  # Used for RNNs
        sequence_length = sequence_length + 1  # Because of shifting
        raw_decoder_input = common_layers.shift_right(raw_targets)
        terminal_decoder_bias, nonterminal_decoder_bias = _get_t_nt_bias(
            raw_decoder_input, hparams, decoder_self_attention_bias)
        raw_decoder_input = tf.squeeze(raw_decoder_input, axis=[-2, -1])
        pos_signals = generate_positional_signals(raw_decoder_input, hparams,
                                                  terminal_decoder_bias,
                                                  nonterminal_decoder_bias)
        pos_embeddings = generate_positional_embeddings(
            pos_signals, hparams.decoder_pos, hparams)
        attention_pos_embeddings = generate_positional_embeddings(
            pos_signals, hparams.decoder_attention_pos, hparams)
        if "sum" in hparams.pos_integration:
            x = x + pos_embeddings
        elif "ffn" in hparams.pos_integration:
            with tf.variable_scope("pos_ffn"):
                x = tf.concat([x, pos_embeddings], axis=2)
                x = transformer_ffn_layer(x, hparams)
        for layer in xrange(hparams.num_decoder_layers
                            or hparams.num_hidden_layers):
            layer_name = "layer_%d" % layer
            layer_cache = cache[layer_name] if cache is not None else None
            with tf.variable_scope(layer_name):
                for layer_type in _iter_layer_types(
                        hparams.decoder_layer_types, layer):
                    if layer_type == "self_att":
                        with tf.variable_scope("self_attention"):
                            y = model_helper.multihead_attention_qkv(
                                common_layers.layer_preprocess(x, hparams),
                                None,
                                None,
                                decoder_self_attention_bias,
                                hparams.attention_key_channels
                                or hparams.hidden_size,
                                hparams.attention_value_channels
                                or hparams.hidden_size,
                                hparams.hidden_size,
                                hparams.num_heads,
                                hparams.attention_dropout,
                                attention_type=hparams.
                                decoder_self_attention_type,
                                attention_order=hparams.attention_order,
                                max_relative_position=hparams.
                                max_relative_position,
                                cache=layer_cache)
                            x = common_layers.layer_postprocess(x, y, hparams)
                    elif layer_type == "nt_self_att":
                        with tf.variable_scope("nonterminal_self_attention"):
                            y = model_helper.multihead_attention_qkv(
                                common_layers.layer_preprocess(x, hparams),
                                None,
                                None,
                                nonterminal_decoder_bias,
                                hparams.attention_key_channels
                                or hparams.hidden_size,
                                hparams.attention_value_channels
                                or hparams.hidden_size,
                                hparams.hidden_size,
                                hparams.num_heads,
                                hparams.attention_dropout,
                                attention_type=hparams.
                                decoder_self_attention_type,
                                attention_order=hparams.attention_order,
                                max_relative_position=hparams.
                                max_relative_position,
                                cache=layer_cache)
                            x = common_layers.layer_postprocess(x, y, hparams)
                    elif layer_type == "t_self_att":
                        with tf.variable_scope("terminal_self_attention"):
                            y = model_helper.multihead_attention_qkv(
                                common_layers.layer_preprocess(x, hparams),
                                None,
                                None,
                                terminal_decoder_bias,
                                hparams.attention_key_channels
                                or hparams.hidden_size,
                                hparams.attention_value_channels
                                or hparams.hidden_size,
                                hparams.hidden_size,
                                hparams.num_heads,
                                hparams.attention_dropout,
                                attention_type=hparams.
                                decoder_self_attention_type,
                                attention_order=hparams.attention_order,
                                max_relative_position=hparams.
                                max_relative_position,
                                cache=layer_cache)
                            x = common_layers.layer_postprocess(x, y, hparams)
                    elif layer_type == "rnn":
                        with tf.variable_scope("recurrent"):
                            y = transformer_rnn_layer(
                                common_layers.layer_preprocess(x, hparams),
                                sequence_length, hparams)
                            x = common_layers.layer_postprocess(x, y, hparams)
                    elif layer_type == "pos_self_att" and attention_pos_embeddings is not None:
                        with tf.variable_scope("pos_self_attention"):
                            y = model_helper.multihead_attention_qkv(
                                attention_pos_embeddings,  # Query
                                attention_pos_embeddings,  # Key
                                common_layers.layer_preprocess(
                                    x, hparams),  # Value
                                decoder_self_attention_bias,
                                hparams.attention_key_channels
                                or hparams.hidden_size,
                                hparams.attention_value_channels
                                or hparams.hidden_size,
                                hparams.hidden_size,
                                hparams.num_heads,
                                hparams.attention_dropout,
                                attention_type=hparams.pos_self_attention_type,
                                attention_order=hparams.attention_order,
                                max_relative_position=hparams.
                                max_relative_position)
                        x = common_layers.layer_postprocess(x, y, hparams)
                    elif layer_type == "enc_att" and encoder_output is not None:
                        with tf.variable_scope("encdec_attention"):
                            # TODO(llion): Add caching.
                            y = model_helper.multihead_attention_qkv(
                                common_layers.layer_preprocess(x, hparams),
                                encoder_output, None,
                                encoder_decoder_attention_bias,
                                hparams.attention_key_channels
                                or hparams.hidden_size,
                                hparams.attention_value_channels
                                or hparams.hidden_size, hparams.hidden_size,
                                hparams.num_heads, hparams.attention_dropout)
                            x = common_layers.layer_postprocess(x, y, hparams)
                    else:
                        tf.logging.warn(
                            "Ignoring '%s' in decoder_layer_types" %
                            layer_type)
                with tf.variable_scope("ffn"):
                    y = transformer_ffn_layer(
                        common_layers.layer_preprocess(x, hparams), hparams)
                    x = common_layers.layer_postprocess(x, y, hparams)
        # if normalization is done in layer_preprocess, then it should also be done
        # on the output, since the output can grow very large, being the sum of
        # a whole stack of unnormalized layer outputs.
        return common_layers.layer_preprocess(x, hparams)
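This decoder derives its positional machinery from the raw targets: common_layers.shift_right builds the decoder input stream, and the measured sequence length gets +1 because the shift prepends a start slot. A toy NumPy sketch of that bookkeeping on a [batch, length] id matrix (assuming, as get_length_from_raw presumably does, that id 0 marks padding):

import numpy as np

def shift_right_sketch(targets, pad_value=0):
    # Prepend a pad/start slot and drop the last position, like
    # common_layers.shift_right applied to a [batch, length] id tensor.
    return np.concatenate(
        [np.full((targets.shape[0], 1), pad_value, targets.dtype), targets[:, :-1]],
        axis=1)

raw_targets = np.array([[7, 8, 9, 0, 0],
                        [5, 6, 0, 0, 0]])          # 0 assumed to be padding
lengths = (raw_targets != 0).sum(axis=1)           # stand-in for get_length_from_raw
decoder_input_ids = shift_right_sketch(raw_targets)
print(decoder_input_ids)
print(lengths + 1)   # +1 because the shift adds a leading start slot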
Example #36
0
def evolved_transformer_decoder(decoder_input,
                                encoder_output,
                                decoder_self_attention_bias,
                                encoder_decoder_attention_bias,
                                hparams,
                                cache=None,
                                decode_loop_step=None,
                                name="decoder",
                                nonpadding=None,
                                save_weights_to=None,
                                make_image_summary=True,
                                losses=None):
    """Evolved Transformer decoder. See arxiv.org/abs/1901.11117 for more details.

  Args:
    decoder_input: a Tensor.
    encoder_output: a Tensor.
    decoder_self_attention_bias: bias Tensor for self-attention (see
      common_attention.attention_bias()).
    encoder_decoder_attention_bias: bias Tensor for encoder-decoder attention
      (see common_attention.attention_bias()).
    hparams: hyperparameters for model.
    cache: dict, containing tensors which are the results of previous
      layers, used for fast decoding.
    decode_loop_step: An integer, step number of the decoding loop. Only used
      for inference on TPU.
    name: a string.
    nonpadding: optional Tensor with shape [batch_size, encoder_length]
      indicating what positions are not padding.  This is used to mask out
      padding in convolutional layers.  We generally only need this mask for
      "packed" datasets, because for ordinary datasets, no padding is ever
      followed by nonpadding.
    save_weights_to: an optional dictionary to capture attention weights for
      visualization; the weights tensor will be appended there under a string
      key created from the variable scope (including name).
    make_image_summary: Whether to make an attention image summary.
    losses: Not supported.

  Returns:
    Decoder output tensor.
  """
    del losses

    num_trainable_top_decoder_layers = hparams.get(
        "num_trainable_top_decoder_layers", -1)  # -1 means train all weights.

    if num_trainable_top_decoder_layers >= 0:
        encoder_output = tf.stop_gradient(encoder_output)

    attention_dropout_broadcast_dims = (
        common_layers.comma_separated_string_to_integer_list(
            getattr(hparams, "attention_dropout_broadcast_dims", "")))

    with tf.variable_scope(name):
        hidden_state = decoder_input

        num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers
        for layer in range(num_layers):
            if num_trainable_top_decoder_layers == num_layers - layer:
                hidden_state = tf.stop_gradient(hidden_state)
            layer_name = "layer_%d" % layer
            layer_cache = cache[layer_name] if cache is not None else None
            with tf.variable_scope(layer_name):

                with tf.variable_scope(_SIXTEEN_HEAD_ATTENTION_NAME):
                    residual_state = hidden_state
                    hidden_state = common_layers.layer_preprocess(
                        hidden_state, hparams)

                    attention_cache = layer_cache[
                        _SIXTEEN_HEAD_ATTENTION_NAME] if layer_cache is not None else None
                    left_state = common_attention.multihead_attention(
                        hidden_state,
                        None,
                        decoder_self_attention_bias,
                        hparams.attention_key_channels or hparams.hidden_size,
                        hparams.attention_value_channels
                        or hparams.hidden_size,
                        hparams.hidden_size,
                        _capped_double_heads(hparams.num_heads),
                        hparams.attention_dropout,
                        attention_type=hparams.self_attention_type,
                        max_relative_position=hparams.max_relative_position,
                        heads_share_relative_embedding=(
                            hparams.heads_share_relative_embedding),
                        add_relative_to_values=hparams.add_relative_to_values,
                        save_weights_to=save_weights_to,
                        cache=attention_cache,
                        make_image_summary=make_image_summary,
                        dropout_broadcast_dims=attention_dropout_broadcast_dims,
                        max_length=hparams.get("max_length"),
                        decode_loop_step=decode_loop_step,
                        vars_3d=hparams.get("attention_variables_3d"),
                        activation_dtype=hparams.get("activation_dtype",
                                                     "float32"),
                        weight_dtype=hparams.get("weight_dtype", "float32"))

                if encoder_output is not None:
                    with tf.variable_scope(_FIRST_ATTEND_TO_ENCODER_NAME):
                        attention_cache = (
                            layer_cache[_FIRST_ATTEND_TO_ENCODER_NAME]
                            if layer_cache is not None else None)
                        right_state = common_attention.multihead_attention(
                            hidden_state,
                            encoder_output,
                            encoder_decoder_attention_bias,
                            hparams.attention_key_channels
                            or hparams.hidden_size,
                            hparams.attention_value_channels
                            or hparams.hidden_size,
                            hparams.hidden_size,
                            hparams.num_heads,
                            hparams.attention_dropout,
                            max_relative_position=hparams.
                            max_relative_position,
                            heads_share_relative_embedding=(
                                hparams.heads_share_relative_embedding),
                            add_relative_to_values=hparams.
                            add_relative_to_values,
                            save_weights_to=save_weights_to,
                            cache=attention_cache,
                            make_image_summary=make_image_summary,
                            dropout_broadcast_dims=
                            attention_dropout_broadcast_dims,
                            max_length=hparams.get("max_length"),
                            vars_3d=hparams.get("attention_variables_3d"),
                            activation_dtype=hparams.get(
                                "activation_dtype", "float32"),
                            weight_dtype=hparams.get("weight_dtype",
                                                     "float32"))

                        left_state = tf.nn.dropout(
                            left_state,
                            1 - hparams.layer_prepostprocess_dropout)
                        right_state = tf.nn.dropout(
                            right_state,
                            1 - hparams.layer_prepostprocess_dropout)

                        hidden_state = residual_state + left_state + right_state

                else:
                    hidden_state = common_layers.layer_postprocess(
                        residual_state, left_state, hparams)

                with tf.variable_scope(_CONV_BRANCHES_NAME):
                    residual_state = hidden_state
                    hidden_state = common_layers.layer_preprocess(
                        hidden_state, hparams)

                    if nonpadding is not None:
                        # Mask padding from conv layers.
                        mask = tf.tile(tf.expand_dims(nonpadding, 2),
                                       [1, 1, hparams.hidden_size])
                        hidden_state *= mask

                    if layer_cache:
                        if decode_loop_step is None:
                            hidden_state = layer_cache[
                                _CONV_BRANCHES_FIRST_LAYER_NAME] = tf.concat(
                                    [
                                        layer_cache[
                                            _CONV_BRANCHES_FIRST_LAYER_NAME],
                                        hidden_state
                                    ],
                                    axis=1)[:,
                                            -1 * _DECODER_LEFT_CONV_PADDING -
                                            1:, :]
                            left_state = hidden_state
                            right_state = hidden_state[:,
                                                       _DECODER_LEFT_CONV_PADDING
                                                       -
                                                       _DECODER_RIGHT_CONV_PADDING:, :]

                        else:
                            # Inplace update is required for inference on TPU.
                            # Inplace_ops only supports inplace_update on the first dimension.
                            tmp = tf.transpose(
                                layer_cache[_CONV_BRANCHES_FIRST_LAYER_NAME],
                                perm=[1, 0, 2])
                            tmp = tf.expand_dims(tmp, axis=1)
                            tmp = inplace_ops.alias_inplace_update(
                                tmp,
                                decode_loop_step * tf.shape(hidden_state)[1] +
                                _DECODER_LEFT_CONV_PADDING,
                                tf.transpose(hidden_state, perm=[1, 0, 2]))
                            tmp = tf.squeeze(tmp, axis=1)
                            hidden_state = layer_cache[
                                _CONV_BRANCHES_FIRST_LAYER_NAME] = tf.transpose(
                                    tmp, perm=[1, 0, 2])

                            batch_size = hidden_state.shape.as_list()[0]
                            left_state = tf.slice(
                                hidden_state, [0, decode_loop_step, 0], [
                                    batch_size, _DECODER_LEFT_CONV_PADDING + 1,
                                    hparams.hidden_size
                                ])
                            right_state = tf.slice(hidden_state, [
                                0,
                                decode_loop_step + _DECODER_LEFT_CONV_PADDING -
                                _DECODER_RIGHT_CONV_PADDING, 0
                            ], [
                                batch_size, _DECODER_RIGHT_CONV_PADDING + 1,
                                hparams.hidden_size
                            ])

                    else:  # No caching.
                        left_state = tf.pad(
                            hidden_state,
                            paddings=[[0, 0], [_DECODER_LEFT_CONV_PADDING, 0],
                                      [0, 0]])
                        right_state = tf.pad(
                            hidden_state,
                            paddings=[[0, 0], [_DECODER_RIGHT_CONV_PADDING, 0],
                                      [0, 0]])

                    left_output_dim = int(hparams.hidden_size * 2)
                    separable_conv_11x1 = tf.layers.SeparableConv1D(
                        left_output_dim,
                        11,
                        padding="VALID",
                        name="separable_conv11x1",
                        activation=tf.nn.relu)
                    left_state = separable_conv_11x1.apply(left_state)
                    left_state = tf.nn.dropout(
                        left_state, 1 - hparams.layer_prepostprocess_dropout)

                    right_output_dim = int(hparams.hidden_size / 2)
                    separable_conv_7x1_1 = tf.layers.SeparableConv1D(
                        right_output_dim,
                        7,
                        padding="VALID",
                        name="separable_conv_7x1_1")
                    right_state = separable_conv_7x1_1.apply(right_state)
                    right_state = tf.nn.dropout(
                        right_state, 1 - hparams.layer_prepostprocess_dropout)
                    right_state = tf.pad(
                        right_state, [[0, 0], [0, 0],
                                      [0, left_output_dim - right_output_dim]],
                        constant_values=0)

                    hidden_state = left_state + right_state

                    hidden_state = common_layers.layer_preprocess(
                        hidden_state, hparams)
                    if nonpadding is not None:
                        # Mask padding from conv layers.
                        mask = tf.tile(tf.expand_dims(nonpadding, 2),
                                       [1, 1, hparams.hidden_size * 2])
                        hidden_state *= mask

                    if layer_cache:
                        if decode_loop_step is None:
                            hidden_state = layer_cache[
                                _CONV_BRANCHES_SECOND_LAYER_NAME] = tf.concat(
                                    [
                                        layer_cache[
                                            _CONV_BRANCHES_SECOND_LAYER_NAME],
                                        hidden_state
                                    ],
                                    axis=1)[:,
                                            -1 * _DECODER_FINAL_CONV_PADDING -
                                            1:, :]

                        else:
                            # Inplace update is required for inference on TPU.
                            # Inplace_ops only supports inplace_update on the first dimension.
                            tmp = tf.transpose(
                                layer_cache[_CONV_BRANCHES_SECOND_LAYER_NAME],
                                perm=[1, 0, 2])
                            tmp = tf.expand_dims(tmp, axis=1)
                            tmp = inplace_ops.alias_inplace_update(
                                tmp, (decode_loop_step +
                                      _DECODER_FINAL_CONV_PADDING) *
                                tf.shape(hidden_state)[1],
                                tf.transpose(hidden_state, perm=[1, 0, 2]))
                            tmp = tf.squeeze(tmp, axis=1)
                            hidden_state = layer_cache[
                                _CONV_BRANCHES_SECOND_LAYER_NAME] = tf.transpose(
                                    tmp, perm=[1, 0, 2])

                            batch_size = hidden_state.shape.as_list()[0]
                            hidden_state = tf.slice(
                                hidden_state, [0, decode_loop_step, 0], [
                                    batch_size, _DECODER_FINAL_CONV_PADDING +
                                    1, hparams.hidden_size * 2
                                ])
                    else:
                        hidden_state = tf.pad(
                            hidden_state,
                            paddings=[[0, 0], [_DECODER_FINAL_CONV_PADDING, 0],
                                      [0, 0]])

                    separable_conv_7x1_2 = tf.layers.SeparableConv1D(
                        hparams.hidden_size,
                        7,
                        padding="VALID",
                        name="separable_conv_7x1_2")
                    hidden_state = separable_conv_7x1_2.apply(hidden_state)

                    hidden_state = common_layers.layer_postprocess(
                        residual_state, hidden_state, hparams)

                with tf.variable_scope(_VANILLA_ATTENTION_NAME):
                    residual_state = hidden_state
                    hidden_state = common_layers.layer_preprocess(
                        hidden_state, hparams)

                    attention_cache = layer_cache[
                        _VANILLA_ATTENTION_NAME] if layer_cache is not None else None
                    hidden_state = common_attention.multihead_attention(
                        hidden_state,
                        None,
                        decoder_self_attention_bias,
                        hparams.attention_key_channels or hparams.hidden_size,
                        hparams.attention_value_channels
                        or hparams.hidden_size,
                        hparams.hidden_size,
                        hparams.num_heads,
                        hparams.attention_dropout,
                        attention_type=hparams.self_attention_type,
                        max_relative_position=hparams.max_relative_position,
                        heads_share_relative_embedding=(
                            hparams.heads_share_relative_embedding),
                        add_relative_to_values=hparams.add_relative_to_values,
                        save_weights_to=save_weights_to,
                        cache=attention_cache,
                        make_image_summary=make_image_summary,
                        dropout_broadcast_dims=attention_dropout_broadcast_dims,
                        max_length=hparams.get("max_length"),
                        decode_loop_step=decode_loop_step,
                        vars_3d=hparams.get("attention_variables_3d"),
                        activation_dtype=hparams.get("activation_dtype",
                                                     "float32"),
                        weight_dtype=hparams.get("weight_dtype", "float32"))
                    hidden_state = common_layers.layer_postprocess(
                        residual_state, hidden_state, hparams)

                if encoder_output is not None:
                    with tf.variable_scope(_SECOND_ATTEND_TO_ENCODER_NAME):
                        residual_state = hidden_state
                        hidden_state = common_layers.layer_preprocess(
                            hidden_state, hparams)

                        attention_cache = (
                            layer_cache[_SECOND_ATTEND_TO_ENCODER_NAME]
                            if layer_cache is not None else None)
                        hidden_state = common_attention.multihead_attention(
                            hidden_state,
                            encoder_output,
                            encoder_decoder_attention_bias,
                            hparams.attention_key_channels
                            or hparams.hidden_size,
                            hparams.attention_value_channels
                            or hparams.hidden_size,
                            hparams.hidden_size,
                            hparams.num_heads,
                            hparams.attention_dropout,
                            max_relative_position=hparams.
                            max_relative_position,
                            heads_share_relative_embedding=(
                                hparams.heads_share_relative_embedding),
                            add_relative_to_values=hparams.
                            add_relative_to_values,
                            save_weights_to=save_weights_to,
                            cache=attention_cache,
                            make_image_summary=make_image_summary,
                            dropout_broadcast_dims=
                            attention_dropout_broadcast_dims,
                            max_length=hparams.get("max_length"),
                            vars_3d=hparams.get("attention_variables_3d"),
                            activation_dtype=hparams.get(
                                "activation_dtype", "float32"),
                            weight_dtype=hparams.get("weight_dtype",
                                                     "float32"))
                        hidden_state = common_layers.layer_postprocess(
                            residual_state, hidden_state, hparams)

                with tf.variable_scope("dense_layers"):
                    residual_state = hidden_state
                    hidden_state = common_layers.layer_preprocess(
                        hidden_state, hparams)

                    hidden_state = tf.layers.dense(hidden_state,
                                                   int(hparams.hidden_size *
                                                       4),
                                                   activation=tf.nn.swish)
                    hidden_state = tf.nn.dropout(
                        hidden_state, 1 - hparams.layer_prepostprocess_dropout)

                    hidden_state = common_layers.layer_preprocess(
                        hidden_state, hparams)

                    hidden_state = tf.layers.dense(hidden_state,
                                                   hparams.hidden_size)
                    hidden_state = common_layers.layer_postprocess(
                        residual_state, hidden_state, hparams)

        decoder_output = common_layers.layer_preprocess(hidden_state, hparams)
        if num_trainable_top_decoder_layers == 0:
            decoder_output = tf.stop_gradient(decoder_output)
        return decoder_output
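The convolutional branches above left-pad the decoder stream before each SeparableConv1D with padding="VALID", so a kernel of width k only sees the current and the k - 1 previous positions; the _DECODER_*_CONV_PADDING constants encode exactly that amount of left padding (presumably 10 and 6 for the 11- and 7-wide convolutions). A NumPy sketch of the causal left-padding idea with a toy 1-D convolution:

import numpy as np

def causal_conv1d_sketch(x, kernel):
    """Toy causal 1-D convolution: left-pad by k - 1, then 'VALID' convolve."""
    k = len(kernel)
    padded = np.concatenate([np.zeros(k - 1), x])          # left padding only
    return np.array([padded[i:i + k] @ kernel for i in range(len(x))])

x = np.arange(1.0, 7.0)          # positions 1..6
kernel = np.ones(3) / 3.0        # width-3 average, sees current + 2 previous steps
print(causal_conv1d_sketch(x, kernel))
# Output at position t depends only on x[:t+1], so decoding stays autoregressive.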
Example #37
0
def transformer_encoder(encoder_input,
                        encoder_self_attention_bias,
                        hparams,
                        name="encoder",
                        nonpadding=None,
                        save_weights_to=None,
                        make_image_summary=True,
                        losses=None):
    """A stack of transformer layers.

  Args:
    encoder_input: a Tensor
    encoder_self_attention_bias: bias Tensor for self-attention
       (see common_attention.attention_bias())
    hparams: hyperparameters for model
    name: a string
    nonpadding: optional Tensor with shape [batch_size, encoder_length]
      indicating what positions are not padding.  This must either be
      passed in, which we do for "packed" datasets, or inferred from
      encoder_self_attention_bias.  The knowledge about padding is used
      for pad_remover(efficiency) and to mask out padding in convolutional
      layers.
    save_weights_to: an optional dictionary to capture attention weights
      for visualization; the weights tensor will be appended there under
      a string key created from the variable scope (including name).
    make_image_summary: Whether to make an attention image summary.
    losses: optional list onto which to append extra training losses

  Returns:
    y: a Tensor
  """
    x = encoder_input
    attention_dropout_broadcast_dims = (
        common_layers.comma_separated_string_to_integer_list(
            getattr(hparams, "attention_dropout_broadcast_dims", "")))
    with tf.variable_scope(name):
        if nonpadding is not None:
            padding = 1.0 - nonpadding
        else:
            padding = common_attention.attention_bias_to_padding(
                encoder_self_attention_bias)
            nonpadding = 1.0 - padding
        pad_remover = None
        if hparams.use_pad_remover and not common_layers.is_on_tpu():
            pad_remover = expert_utils.PadRemover(padding)
        for layer in range(hparams.num_encoder_layers
                           or hparams.num_hidden_layers):
            with tf.variable_scope("layer_%d" % layer):
                with tf.variable_scope("self_attention"):
                    # sg: imdb comments
                    y = common_attention.multihead_attention(
                        common_layers.layer_preprocess(
                            x, hparams),  # added layer norm
                        None,
                        encoder_self_attention_bias,
                        hparams.attention_key_channels
                        or hparams.hidden_size,  # 128
                        hparams.attention_value_channels
                        or hparams.hidden_size,  # 128
                        hparams.hidden_size,  # 128
                        hparams.num_heads,  # 4
                        hparams.attention_dropout,  # 0.1
                        attention_type=hparams.
                        self_attention_type,  # 'dot_product'
                        save_weights_to=save_weights_to,
                        max_relative_position=hparams.
                        max_relative_position,  # 0
                        make_image_summary=make_image_summary,
                        dropout_broadcast_dims=attention_dropout_broadcast_dims,
                        max_length=hparams.get("max_length"))  # 256
                    x = common_layers.layer_postprocess(x, y, hparams)
                with tf.variable_scope("ffn"):
                    y = transformer_ffn_layer(common_layers.layer_preprocess(
                        x, hparams),
                                              hparams,
                                              pad_remover,
                                              conv_padding="SAME",
                                              nonpadding_mask=nonpadding,
                                              losses=losses)
                    x = common_layers.layer_postprocess(x, y, hparams)
        # if normalization is done in layer_preprocess, then it should also be done
        # on the output, since the output can grow very large, being the sum of
        # a whole stack of unnormalized layer outputs.
        return common_layers.layer_preprocess(x, hparams)
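When nonpadding is not passed in, the encoder recovers it from the self-attention bias: padded key positions carry a very large negative bias so the softmax ignores them, and common_attention.attention_bias_to_padding inverts that encoding back into a 0/1 padding indicator. A small NumPy sketch of that round trip (toy helpers, not the tensor2tensor implementations):

import numpy as np

NEG_INF = -1e9

def attention_bias_from_padding_sketch(padding):
    # padding: [batch, length] with 1.0 at padded positions.
    return padding * NEG_INF        # bias added to attention logits

def attention_bias_to_padding_sketch(bias):
    # Anything strongly negative is treated as padding.
    return (bias < NEG_INF / 2).astype(np.float32)

padding = np.array([[0., 0., 1., 1.],
                    [0., 0., 0., 1.]])
bias = attention_bias_from_padding_sketch(padding)
recovered = attention_bias_to_padding_sketch(bias)
print(np.array_equal(recovered, padding))   # True
print(1.0 - recovered)                      # nonpadding mask used for the conv layers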
Example #38
0
def decoder(decoder_input,
            encoder_output,
            decoder_self_attention_bias,
            encoder_decoder_attention_bias,
            hparams,
            name="decoder",
            save_weights_to=None,
            make_image_summary=True,):
  """A stack of transformer layers.

  Args:
    decoder_input: a Tensor
    encoder_output: a Tensor
    decoder_self_attention_bias: bias Tensor for self-attention
      (see common_attention.attention_bias())
    encoder_decoder_attention_bias: bias Tensor for encoder-decoder attention
      (see common_attention.attention_bias())
    hparams: hyperparameters for model
    name: a string
    save_weights_to: an optional dictionary to capture attention weights
      for visualization; the weights tensor will be appended there under
      a string key created from the variable scope (including name).
    make_image_summary: Whether to make an attention image summary.

  Returns:
    y: a Tensor
  """
  x = decoder_input
  with tf.variable_scope(name):
    for layer in range(hparams.num_decoder_layers or hparams.num_hidden_layers):
      layer_name = "layer_%d" % layer
      with tf.variable_scope(layer_name):
        with tf.variable_scope("self_attention"):
          y = common_attention.multihead_attention(
              common_layers.layer_preprocess(x, hparams),
              None,
              decoder_self_attention_bias,
              hparams.attention_key_channels or hparams.hidden_size,
              hparams.attention_value_channels or hparams.hidden_size,
              hparams.hidden_size,
              hparams.num_heads,
              hparams.attention_dropout,
              attention_type=hparams.self_attention_type,
              save_weights_to=save_weights_to,
              make_image_summary=make_image_summary,
              )
          utils.collect_named_outputs("norms",
                                      "decoder_self_attention_%d"%(layer),
                                      tf.norm(y, axis=-1))
          x = common_layers.layer_postprocess(x, y, hparams)
          utils.collect_named_outputs("norms",
                                      "decoder_self_attention_post_%d"%(layer),
                                      tf.norm(x, axis=-1))
        if encoder_output is not None:
          with tf.variable_scope("encdec_attention"):
            y = common_attention.multihead_attention(
                common_layers.layer_preprocess(x, hparams),
                encoder_output,
                encoder_decoder_attention_bias,
                hparams.attention_key_channels or hparams.hidden_size,
                hparams.attention_value_channels or hparams.hidden_size,
                hparams.hidden_size,
                hparams.num_heads,
                hparams.attention_dropout,
                save_weights_to=save_weights_to,
                make_image_summary=make_image_summary,
                )
            utils.collect_named_outputs(
                "norms",
                "decoder_encoder_attention_%d"%(layer),
                tf.norm(y, axis=-1))
            x = common_layers.layer_postprocess(x, y, hparams)
            utils.collect_named_outputs(
                "norms",
                "decoder_encoder_attention_post_%d"%(layer),
                tf.norm(x, axis=-1))
        with tf.variable_scope("ffn"):
          y = common_layers.dense_relu_dense(
              common_layers.layer_preprocess(x, hparams),
              hparams.filter_size,
              hparams.hidden_size,
              dropout=hparams.relu_dropout,
          )
          utils.collect_named_outputs("norms", "decoder_ffn_%d"%(layer),
                                      tf.norm(y, axis=-1))
          x = common_layers.layer_postprocess(x, y, hparams)
          utils.collect_named_outputs("norms", "decoder_ffn_post_%d"%(layer),
                                      tf.norm(x, axis=-1))
    # if normalization is done in layer_preprocess, then it should also be done
    # on the output, since the output can grow very large, being the sum of
    # a whole stack of unnormalized layer outputs.
    return common_layers.layer_preprocess(x, hparams)
Example #39
0
def transformer_encoder(encoder_input,
                        encoder_self_attention_bias,
                        hparams,
                        name="encoder",
                        nonpadding=None,
                        save_weights_to=None,
                        make_image_summary=True):
  """A stack of transformer layers.

  Args:
    encoder_input: a Tensor
    encoder_self_attention_bias: bias Tensor for self-attention
       (see common_attention.attention_bias())
    hparams: hyperparameters for model
    name: a string
    nonpadding: optional Tensor with shape [batch_size, encoder_length]
      indicating what positions are not padding.  This must either be
      passed in, which we do for "packed" datasets, or inferred from
      encoder_self_attention_bias.  The knowledge about padding is used
      for pad_remover(efficiency) and to mask out padding in convolutional
      layers.
    save_weights_to: an optional dictionary to capture attention weights
      for visualization; the weights tensor will be appended there under
      a string key created from the variable scope (including name).
    make_image_summary: Whether to make an attention image summary.

  Returns:
    y: a Tensor
  """
  x = encoder_input
  attention_dropout_broadcast_dims = (
      common_layers.comma_separated_string_to_integer_list(
          getattr(hparams, "attention_dropout_broadcast_dims", "")))
  with tf.variable_scope(name):
    if nonpadding is not None:
      padding = 1.0 - nonpadding
    else:
      padding = common_attention.attention_bias_to_padding(
          encoder_self_attention_bias)
      nonpadding = 1.0 - padding
    pad_remover = None
    if hparams.use_pad_remover and not common_layers.is_on_tpu():
      pad_remover = expert_utils.PadRemover(padding)
    for layer in xrange(hparams.num_encoder_layers or
                        hparams.num_hidden_layers):
      with tf.variable_scope("layer_%d" % layer):
        with tf.variable_scope("self_attention"):
          y = common_attention.multihead_attention(
              common_layers.layer_preprocess(x, hparams),
              None,
              encoder_self_attention_bias,
              hparams.attention_key_channels or hparams.hidden_size,
              hparams.attention_value_channels or hparams.hidden_size,
              hparams.hidden_size,
              hparams.num_heads,
              hparams.attention_dropout,
              attention_type=hparams.self_attention_type,
              save_weights_to=save_weights_to,
              max_relative_position=hparams.max_relative_position,
              make_image_summary=make_image_summary,
              dropout_broadcast_dims=attention_dropout_broadcast_dims)
          x = common_layers.layer_postprocess(x, y, hparams)
        with tf.variable_scope("ffn"):
          y = transformer_ffn_layer(
              common_layers.layer_preprocess(x, hparams), hparams, pad_remover,
              conv_padding="SAME", nonpadding_mask=nonpadding)
          x = common_layers.layer_postprocess(x, y, hparams)
    # if normalization is done in layer_preprocess, then it should also be done
    # on the output, since the output can grow very large, being the sum of
    # a whole stack of unnormalized layer outputs.
    return common_layers.layer_preprocess(x, hparams)
Example #40
0
def invertible_transformer_decoder_attention_unit(
        x,
        hparams,
        encoder_output,
        decoder_self_attention_bias,
        encoder_decoder_attention_bias,
        attention_dropout_broadcast_dims,
        save_weights_to=None,
        make_image_summary=True,
        split_index=0):
    """Applies multihead attention function which is parametrised for decoding.
  Args:
    x: input (decoder input)
    hparams: model hyper-parameters
    encoder_output: Encoder representation. [batch_size, input_length,
      hidden_dim]
    decoder_self_attention_bias: Bias and mask weights for decoder
      self-attention. [batch_size, decoder_length]
    encoder_decoder_attention_bias: Bias and mask weights for encoder-decoder
      attention. [batch_size, input_length]
    attention_dropout_broadcast_dims: Fpr noise broadcasting in the dropout
      layers to save memory during training
    save_weights_to: an optional dictionary to capture attention weights for
      visualization; the weights tensor will be appended there under a string
      key created from the variable scope (including name).
    make_image_summary: Whether to make an attention image summary.
  Returns:
    The output tensor
  """

    ##################
    ## CHANGE START ##
    ##################

    with tf.variable_scope("self_attention"):

        # x is split channel-wise into two halves of size hidden_size // 2,
        # so the attention output depth below must also be hidden_size // 2;
        # emitting a full hidden_size output (e.g. 128 instead of 64) would
        # break the residual add in layer_postprocess.

        x_splits = tf.split(x, num_or_size_splits=2, axis=2)

        y = common_attention.multihead_attention(
            common_layers.layer_preprocess(x_splits[split_index], hparams),
            None,
            decoder_self_attention_bias,
            hparams.attention_key_channels or hparams.hidden_size,
            hparams.attention_value_channels or hparams.hidden_size,
            hparams.hidden_size // 2,  # output depth must match each half
            hparams.num_heads,
            hparams.attention_dropout,
            attention_type=hparams.self_attention_type,
            save_weights_to=save_weights_to,
            max_relative_position=hparams.max_relative_position,
            cache=None,
            make_image_summary=make_image_summary,
            dropout_broadcast_dims=attention_dropout_broadcast_dims,
            hard_attention_k=hparams.hard_attention_k)

        x_splits[1 - split_index] = common_layers.layer_postprocess(
            x_splits[1 - split_index], y, hparams)

    if encoder_output is not None:
        with tf.variable_scope("encdec_attention"):
            y = common_attention.multihead_attention(
                common_layers.layer_preprocess(x_splits[split_index], hparams),
                encoder_output,
                encoder_decoder_attention_bias,
                hparams.attention_key_channels or hparams.hidden_size,
                hparams.attention_value_channels or hparams.hidden_size,
                hparams.hidden_size // 2,  # output depth must match each half
                hparams.num_heads,
                hparams.attention_dropout,
                save_weights_to=save_weights_to,
                make_image_summary=make_image_summary,
                dropout_broadcast_dims=attention_dropout_broadcast_dims,
                hard_attention_k=hparams.hard_attention_k)

            x_splits[1 - split_index] = common_layers.layer_postprocess(
                x_splits[1 - split_index], y, hparams)

    x = tf.concat(x_splits, axis=2)

    ##################
    ##  CHANGE END  ##
    ##################

    return x
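
# Why this coupling is invertible: assuming layer_postprocess reduces to a
# pure residual add (e.g. layer_postprocess_sequence="da" with dropout off at
# inference), the unit is an additive coupling step. A toy numpy sketch, with
# a hypothetical stand-in `_f` playing the role of the attention block:
import numpy as np


def _coupling_forward(x1, x2, f):
  # The split_index half is only read; the complementary half is updated.
  return x1, x2 + f(x1)


def _coupling_inverse(y1, y2, f):
  # Because one half was left untouched, f can be recomputed and subtracted.
  return y1, y2 - f(y1)


_f = lambda h: np.tanh(h)  # stand-in for the attention sub-layer
_x1, _x2 = np.ones((2, 4)), np.zeros((2, 4))
_y1, _y2 = _coupling_forward(_x1, _x2, _f)
assert np.allclose(_coupling_inverse(_y1, _y2, _f)[1], _x2)
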
Example #41
def transformer_encoder(encoder_input,
                        encoder_self_attention_bias,
                        hparams,
                        name="encoder",
                        nonpadding=None,
                        save_weights_to=None,
                        make_image_summary=True,
                        losses=None):
  """A stack of transformer layers.

  Args:
    encoder_input: a Tensor
    encoder_self_attention_bias: bias Tensor for self-attention
       (see common_attention.attention_bias())
    hparams: hyperparameters for model
    name: a string
    nonpadding: optional Tensor with shape [batch_size, encoder_length]
      indicating what positions are not padding.  This must either be
      passed in, which we do for "packed" datasets, or inferred from
      encoder_self_attention_bias.  The knowledge about padding is used
      for pad_remover(efficiency) and to mask out padding in convolutional
      layers.
    save_weights_to: an optional dictionary to capture attention weights
      for visualization; the weights tensor will be appended there under
      a string key created from the variable scope (including name).
    make_image_summary: Whether to make an attention image summary.
    losses: optional list onto which to append extra training losses

  Returns:
    y: a Tensor
  """
  x = encoder_input
  attention_dropout_broadcast_dims = (
      common_layers.comma_separated_string_to_integer_list(
          getattr(hparams, "attention_dropout_broadcast_dims", "")))
  mlperf_log.transformer_print(
      key=mlperf_log.MODEL_HP_NUM_HIDDEN_LAYERS,
      value=hparams.num_encoder_layers or hparams.num_hidden_layers)
  mlperf_log.transformer_print(
      key=mlperf_log.MODEL_HP_ATTENTION_DROPOUT,
      value=hparams.attention_dropout)
  mlperf_log.transformer_print(
      key=mlperf_log.MODEL_HP_ATTENTION_DENSE,
      value={
          "use_bias": "false",
          "num_heads": hparams.num_heads,
          "hidden_size": hparams.hidden_size
      })

  with tf.variable_scope(name):
    if nonpadding is not None:
      padding = 1.0 - nonpadding
    else:
      padding = common_attention.attention_bias_to_padding(
          encoder_self_attention_bias)
      nonpadding = 1.0 - padding
    pad_remover = None
    if hparams.use_pad_remover and not common_layers.is_xla_compiled():
      pad_remover = expert_utils.PadRemover(padding)
    for layer in range(hparams.num_encoder_layers or hparams.num_hidden_layers):
      with tf.variable_scope("layer_%d" % layer):
        with tf.variable_scope("self_attention"):
          y = common_attention.multihead_attention(
              common_layers.layer_preprocess(x, hparams),
              None,
              encoder_self_attention_bias,
              hparams.attention_key_channels or hparams.hidden_size,
              hparams.attention_value_channels or hparams.hidden_size,
              hparams.hidden_size,
              hparams.num_heads,
              hparams.attention_dropout,
              attention_type=hparams.self_attention_type,
              max_relative_position=hparams.max_relative_position,
              heads_share_relative_embedding=(
                  hparams.heads_share_relative_embedding),
              add_relative_to_values=hparams.add_relative_to_values,
              save_weights_to=save_weights_to,
              make_image_summary=make_image_summary,
              dropout_broadcast_dims=attention_dropout_broadcast_dims,
              max_length=hparams.get("max_length"),
              vars_3d=hparams.get("attention_variables_3d"))
          x = common_layers.layer_postprocess(x, y, hparams)
        with tf.variable_scope("ffn"):
          y = transformer_ffn_layer(
              common_layers.layer_preprocess(x, hparams),
              hparams,
              pad_remover,
              conv_padding="SAME",
              nonpadding_mask=nonpadding,
              losses=losses)
          x = common_layers.layer_postprocess(x, y, hparams)
    # if normalization is done in layer_preprocess, then it should also be done
    # on the output, since the output can grow very large, being the sum of
    # a whole stack of unnormalized layer outputs.
    mlperf_log.transformer_print(
        key=mlperf_log.MODEL_HP_NORM,
        value={"hidden_size": hparams.hidden_size})
    return common_layers.layer_preprocess(x, hparams)
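
# A minimal usage sketch for transformer_encoder above, assuming a standard
# transformer_base() hparams set; the random embeddings and the helper name
# are illustrative only.
import tensorflow as tf
from tensor2tensor.layers import common_attention
from tensor2tensor.models import transformer


def _encoder_usage_sketch():
  hparams = transformer.transformer_base()
  # A made-up batch of already-embedded inputs: [batch, length, hidden_size].
  encoder_embeddings = tf.random_normal([2, 7, hparams.hidden_size])
  padding = common_attention.embedding_to_padding(encoder_embeddings)
  bias = common_attention.attention_bias_ignore_padding(padding)
  # save_weights_to hook from the docstring: each attention call appends its
  # weights under a key derived from its variable scope, handy for
  # visualization.
  attn_weights = {}
  encoder_output = transformer_encoder(
      encoder_embeddings, bias, hparams,
      save_weights_to=attn_weights, make_image_summary=False)
  return encoder_output, attn_weights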