def attention_lm_moe_prepare_decoder(targets, hparams):
    """Build decoder input, self-attention bias and pad remover for one shard.

    Args:
        targets: a Tensor holding the target embeddings.
        hparams: run hyperparameters; `prepend_mode` and `pos` are read here.

    Returns:
        A 3-tuple `(decoder_input, decoder_self_attention_bias, pad_remover)`:
          decoder_input: a Tensor, bottom of the decoder stack (targets passed
            through shift_right_3d, plus a timing signal when
            `hparams.pos == "timing"`).
          decoder_self_attention_bias: a Tensor containing large negative
            values that implement masked attention.
          pad_remover (expert_utils.PadRemover): a utility object, built from
            the padding mask of `targets`, used to remove padding.
    """
    padding_mask = common_attention.embedding_to_padding(targets)
    with tf.name_scope("pad_remover"):
        # The shift_right means the <eos> token is treated as padding. That
        # is harmless in practice: the triangular mask means this token is
        # never attended to anyway.
        remover = expert_utils.PadRemover(padding_mask)

    full_attention_on_inputs = (
        hparams.prepend_mode == "prepend_inputs_full_attention")
    if full_attention_on_inputs:
        self_attention_bias = (
            common_attention.attention_bias_prepend_inputs_full_attention(
                padding_mask))
    else:
        # Lower-triangle mask sized by the (dynamic) target length.
        sequence_length = tf.shape(targets)[1]
        self_attention_bias = common_attention.attention_bias_lower_triangle(
            sequence_length)

    shifted_targets = common_layers.shift_right_3d(targets)
    if hparams.pos == "timing":
        shifted_targets = common_attention.add_timing_signal_1d(shifted_targets)
    return shifted_targets, self_attention_bias, remover
def attention_lm_moe_prepare_decoder(targets, hparams):
  """Prepare one shard of the mixture-of-experts LM for the decoder.

  Args:
    targets: a Tensor of target embeddings.
    hparams: run hyperparameters (reads `prepend_mode` and `pos`).

  Returns:
    decoder_input: a Tensor, bottom of decoder stack.
    decoder_self_attention_bias: a Tensor containing large negative values
      to implement masked attention.
    pad_remover (expert_utils.PadRemover): a utility object, built from the
      padding mask of `targets`, used to remove padding.
  """
  pad_mask = common_attention.embedding_to_padding(targets)
  with tf.name_scope("pad_remover"):
    # After shift_right the <eos> token is counted as padding. This does not
    # matter in practice: the triangular mask keeps it from ever being
    # attended to.
    remover = expert_utils.PadRemover(pad_mask)

  use_full_input_attention = (
      hparams.prepend_mode == "prepend_inputs_full_attention")
  # Only the selected branch of the conditional expression is evaluated.
  bias = (
      common_attention.attention_bias_prepend_inputs_full_attention(pad_mask)
      if use_full_input_attention else
      common_attention.attention_bias_lower_triangle(tf.shape(targets)[1]))

  decoder_input = common_layers.shift_right_3d(targets)
  if hparams.pos == "timing":
    decoder_input = common_attention.add_timing_signal_1d(decoder_input)
  return decoder_input, bias, remover
# Example #3
# 0
def transformer_prepare_decoder_right(targets, hparams, features=None):
    """Prepare one shard of the model for a decoder fed left-shifted targets.

    Args:
        targets: a Tensor of target embeddings.
        hparams: run hyperparameters. Reads `causal_decoder_self_attention`,
            `prepend_mode`, `proximity_bias`, `pos`, `max_length` (only when
            `pos == "emb"`) and `activation_dtype`.
        features: optionally pass the entire features dictionary as well.
            This is needed for "packed" datasets, in which case it must
            provide "targets_segmentation" and "targets_position".

    Returns:
        A pair `(decoder_input, decoder_self_attention_bias)`:
          decoder_input: a Tensor, bottom of decoder stack (built with
            shift_left_3d).
          decoder_self_attention_bias: a bias tensor for use in decoder
            self-attention, cast to bfloat16 when
            `hparams.activation_dtype == "bfloat16"`.
    """
    if not hparams.causal_decoder_self_attention:
        # Full attention: only padding positions are masked out.
        padding = common_attention.embedding_to_padding(targets)
        bias = common_attention.attention_bias_ignore_padding(padding)
    elif hparams.prepend_mode == "prepend_inputs_full_attention":
        bias = common_attention.attention_bias_prepend_inputs_full_attention(
            common_attention.embedding_to_padding(targets))
    else:
        # Causal attention.
        # NOTE(review): paired with shift_left_3d below, the (0, -1) band
        # arguments presumably restrict each position to itself and later
        # positions (right-to-left decoding) -- confirm against
        # attention_bias_local's signature.
        bias = common_attention.attention_bias_local(
            common_layers.shape_list(targets)[1], 0, -1)

    targets_position = None
    if features and "targets_segmentation" in features:
        # "Packed" dataset - keep the examples from seeing each other.
        segmentation = features["targets_segmentation"]
        targets_position = features["targets_position"]
        bias += common_attention.attention_bias_same_segment(
            segmentation, segmentation)

    if hparams.proximity_bias:
        bias += common_attention.attention_bias_proximal(
            common_layers.shape_list(targets)[1])

    decoder_input = shift_left_3d(targets)
    if hparams.pos == "timing":
        if targets_position is None:
            decoder_input = common_attention.add_timing_signal_1d(
                decoder_input)
        else:
            decoder_input = (
                common_attention.add_timing_signal_1d_given_position(
                    decoder_input, targets_position))
    elif hparams.pos == "emb":
        decoder_input = common_attention.add_positional_embedding(
            decoder_input, hparams.max_length, "targets_positional_embedding",
            targets_position)

    if hparams.activation_dtype == "bfloat16":
        bias = tf.cast(bias, tf.bfloat16)
    return (decoder_input, bias)
# Example #4
# 0
def transformer_prepare_decoder(targets, hparams, features=None):
  """Prepare one shard of the model for the decoder.

  Args:
    targets: a Tensor of target embeddings.
    hparams: run hyperparameters. Reads `prepend_mode`, `proximity_bias`
      and `pos`.
    features: optionally pass the entire features dictionary as well.
      This is needed now for "packed" datasets, in which case it must
      contain "targets_segmentation" and "targets_position".

  Returns:
    decoder_input: a Tensor, bottom of decoder stack (targets passed through
      shift_right_3d, plus a timing signal when `hparams.pos == "timing"`).
    decoder_self_attention_bias: a bias tensor for use in decoder
      self-attention.
  """
  if hparams.prepend_mode == "prepend_inputs_full_attention":
    decoder_self_attention_bias = (
        common_attention.attention_bias_prepend_inputs_full_attention(
            common_attention.embedding_to_padding(targets)))
  else:
    # Lower-triangle mask over the target length.
    decoder_self_attention_bias = (
        common_attention.attention_bias_lower_triangle(
            common_layers.shape_list(targets)[1]))

  if features and "targets_segmentation" in features:
    # "Packed" dataset - keep the examples from seeing each other.
    targets_segmentation = features["targets_segmentation"]
    targets_position = features["targets_position"]
    decoder_self_attention_bias += common_attention.attention_bias_same_segment(
        targets_segmentation, targets_segmentation)
  else:
    targets_position = None
  if hparams.proximity_bias:
    decoder_self_attention_bias += common_attention.attention_bias_proximal(
        common_layers.shape_list(targets)[1])
  decoder_input = common_layers.shift_right_3d(targets)
  # NOTE(review): only the "timing" positional signal is added here; any
  # other hparams.pos value (e.g. "emb") adds no positional information --
  # confirm that is intended.
  if hparams.pos == "timing":
    if targets_position is not None:
      # Packed examples restart positions per segment, so use the explicit
      # positions instead of 0..length-1.
      decoder_input = common_attention.add_timing_signal_1d_given_position(
          decoder_input, targets_position)
    else:
      decoder_input = common_attention.add_timing_signal_1d(decoder_input)
  return (decoder_input, decoder_self_attention_bias)
# Example #5
# 0
def attention_lm_prepare_decoder(targets, hparams):
    """Prepare one shard of the attention LM for the decoder.

    Args:
        targets: a Tensor of target embeddings.
        hparams: run hyperparameters (reads `prepend_mode` and `pos`).

    Returns:
        decoder_input: a Tensor, bottom of the decoder stack (targets passed
            through shift_right_3d, plus a timing signal when
            `hparams.pos == "timing"`).
        decoder_self_attention_bias: a Tensor containing large negative values
            that implement masked attention.
    """
    if hparams.prepend_mode != "prepend_inputs_full_attention":
        # Lower-triangle mask over the target length.
        length = common_layers.shape_list(targets)[1]
        bias = common_attention.attention_bias_lower_triangle(length)
    else:
        padding = common_attention.embedding_to_padding(targets)
        bias = common_attention.attention_bias_prepend_inputs_full_attention(
            padding)

    shifted = common_layers.shift_right_3d(targets)
    if hparams.pos == "timing":
        shifted = common_attention.add_timing_signal_1d(shifted)
    return (shifted, bias)
# Example #6
# 0
def attention_lm_prepare_decoder(targets, hparams):
  """Prepare one shard of the attention LM for the decoder.

  Args:
    targets: a Tensor of target embeddings.
    hparams: run hyperparameters (reads `prepend_mode` and `pos`).

  Returns:
    decoder_input: a Tensor, bottom of decoder stack.
    decoder_self_attention_bias: a Tensor containing large negative values
      to implement masked attention.
  """
  prepend_full = hparams.prepend_mode == "prepend_inputs_full_attention"
  # Only the selected branch of the conditional expression is evaluated.
  self_bias = (
      common_attention.attention_bias_prepend_inputs_full_attention(
          common_attention.embedding_to_padding(targets))
      if prepend_full else
      common_attention.attention_bias_lower_triangle(
          common_layers.shape_list(targets)[1]))

  inputs = common_layers.shift_right_3d(targets)
  if hparams.pos == "timing":
    inputs = common_attention.add_timing_signal_1d(inputs)
  return (inputs, self_bias)