Example #1
def attention_lm_moe_prepare_decoder(targets, hparams):
    """Prepare one shard of the model for the decoder.

  Args:
    targets: a Tensor.
    hparams: run hyperparameters

  Returns:
    decoder_input: a Tensor, bottom of decoder stack
    decoder_self_attention_bias: a Tensor, containing large negative values
    to implement masked attention and possibly baises for diagonal alignments
    pad_remover (expert_utils.PadRemover): an util object to remove padding
  """
    targets_pad_mask = common_attention.embedding_to_padding(targets)
    with tf.name_scope("pad_remover"):
        pad_remover = expert_utils.PadRemover(targets_pad_mask)

    if hparams.prepend_mode == "prepend_inputs_full_attention":
        decoder_self_attention_bias = (
            common_attention.attention_bias_prepended(targets_pad_mask))
    else:
        decoder_self_attention_bias = (
            common_attention.attention_bias_lower_triangle(
                tf.shape(targets)[1]))
    decoder_input = common_layers.shift_left_3d(targets)
    if hparams.pos == "timing":
        decoder_input = common_attention.add_timing_signal_1d(decoder_input)
    return (decoder_input, decoder_self_attention_bias, pad_remover)
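For context, here is a minimal, self-contained sketch of the padding mask this function hands to PadRemover, assuming embedding_to_padding marks a position as padding (1.0) when its embedding vector is all zeros; the formula below illustrates that assumption and is not the library implementation.

import tensorflow as tf

def padding_mask_from_embeddings(targets):
  # targets: [batch, length, depth] float Tensor of target embeddings.
  # Returns a [batch, length] float mask that is 1.0 where the embedding
  # vector is all zeros (a padded position) and 0.0 elsewhere; this is the
  # kind of mask PadRemover consumes to drop padded positions.
  return tf.cast(
      tf.equal(tf.reduce_sum(tf.abs(targets), axis=-1), 0.0), tf.float32)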
Example #2
def prepare_decoder(targets, target_space_emb):
    """Prepare decoder."""
    decoder_self_attention_bias = (
        common_attention.attention_bias_lower_triangle(tf.shape(targets)[1]))
    target_space_emb = tf.reshape(target_space_emb, [1, 1, -1])
    target_space_emb = tf.tile(target_space_emb, [tf.shape(targets)[0], 1, 1])
    decoder_input = common_layers.shift_left_3d(targets,
                                                pad_value=target_space_emb)
    decoder_input = common_attention.add_timing_signal_1d(decoder_input)
    return (decoder_input, decoder_self_attention_bias)
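The shift this variant relies on can be written out directly. The sketch below assumes shift_left_3d(targets, pad_value=emb) prepends the given [batch, 1, depth] embedding along the length axis and drops the last position, so the target-space embedding becomes the decoder input at position 0; it is an illustration of that intent, not the library code.

import tensorflow as tf

def shift_targets_with_pad_value(targets, pad_value):
  # targets: [batch, length, depth]; pad_value: [batch, 1, depth].
  # Prepend pad_value along the length axis and drop the last position, so
  # position i of the decoder input holds target i - 1 (teacher forcing) and
  # position 0 holds the target-space embedding.
  return tf.concat([pad_value, targets], axis=1)[:, :-1, :]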
Example #3
def attention_lm_prepare_decoder(targets, hparams):
    """Prepare one shard of the model for the decoder.

  Args:
    targets: a Tensor.
    hparams: run hyperparameters

  Returns:
    decoder_input: a Tensor, bottom of decoder stack
    decoder_self_attention_bias: a Tensor, containing large negative values
    to implement masked attention and possibly baises for diagonal alignments
  """
    decoder_self_attention_bias = (
        common_attention.attention_bias_lower_triangle(tf.shape(targets)[1]))
    decoder_input = common_layers.shift_left_3d(targets)
    if hparams.pos == "timing":
        decoder_input = common_attention.add_timing_signal_1d(decoder_input)
    return (decoder_input, decoder_self_attention_bias)
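All of these helpers build the same causal mask. The sketch below shows the bias that attention_bias_lower_triangle is expected to produce, assuming the usual convention of a [1, 1, length, length] tensor that is 0 on and below the diagonal and a large negative number above it; the library's internals may differ.

import tensorflow as tf

def lower_triangle_attention_bias(length, neg_inf=-1e9):
  # Lower-triangular matrix of ones: position i may attend to positions j <= i.
  band = tf.linalg.band_part(tf.ones([length, length]), -1, 0)
  # 0.0 where attention is allowed, a large negative value where it is not.
  # The bias is added to the attention logits before the softmax, which
  # drives the masked positions' weights to (effectively) zero.
  bias = neg_inf * (1.0 - band)
  return tf.reshape(bias, [1, 1, length, length])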
Example #4
def long_answer_prepare_decoder(inputs, targets, hparams):
  """Prepare one shard of the model for the decoder.

  Args:
    inputs: a Tensor.
    targets: a Tensor.
    hparams: run hyperparameters

  Returns:
    decoder_input: a Tensor, bottom of decoder stack
  """
  decoder_input = tf.concat([
      length_embedding(targets, hparams), inputs,
      common_layers.shift_left_3d(targets)
  ], 1)
  if hparams.pos == "timing":
    decoder_input = common_attention.add_timing_signal_1d(decoder_input)
  return decoder_input
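The shapes going into the tf.concat above are worth spelling out. Assuming length_embedding returns a [batch, 1, depth] embedding (an assumption about a helper not shown here), inputs is [batch, inputs_length, depth] and the shifted targets are [batch, targets_length, depth], so concatenating along axis 1 yields a single decoder input of length 1 + inputs_length + targets_length:

import tensorflow as tf

def concat_decoder_input(length_emb, inputs, shifted_targets):
  # length_emb:      [batch, 1, depth]               (assumed shape)
  # inputs:          [batch, inputs_length, depth]
  # shifted_targets: [batch, targets_length, depth]
  # Result:          [batch, 1 + inputs_length + targets_length, depth]
  return tf.concat([length_emb, inputs, shifted_targets], axis=1)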
Example #5
def transformer_prepare_decoder(targets, hparams):
  """Prepare one shard of the model for the decoder.

  Args:
    targets: a Tensor.
    hparams: run hyperparameters

  Returns:
    decoder_input: a Tensor, bottom of decoder stack
    decoder_self_attention_bias: a bias tensor for use in decoder self-attention
  """
  decoder_self_attention_bias = (
      common_attention.attention_bias_lower_triangle(tf.shape(targets)[1]))
  if hparams.proximity_bias:
    decoder_self_attention_bias += common_attention.attention_bias_proximal(
        tf.shape(targets)[1])
  decoder_input = common_layers.shift_left_3d(targets)
  if hparams.pos == "timing":
    decoder_input = common_attention.add_timing_signal_1d(decoder_input)
  return (decoder_input, decoder_self_attention_bias)
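Since every variant above optionally calls add_timing_signal_1d, here is a self-contained sketch of the standard sinusoidal timing signal from the Transformer paper; the exact channel layout in the library may differ, so treat this as the general technique rather than a copy of the implementation.

import math
import tensorflow as tf

def add_sinusoidal_timing_signal(x, min_timescale=1.0, max_timescale=1.0e4):
  # x: [batch, length, channels]. Adds position-dependent sin/cos signals so
  # the model can recover absolute and relative positions.
  length = tf.shape(x)[1]
  channels = tf.shape(x)[2]
  position = tf.cast(tf.range(length), tf.float32)
  num_timescales = channels // 2
  log_timescale_increment = (
      math.log(float(max_timescale) / float(min_timescale)) /
      tf.maximum(tf.cast(num_timescales, tf.float32) - 1.0, 1.0))
  inv_timescales = min_timescale * tf.exp(
      tf.cast(tf.range(num_timescales), tf.float32) * -log_timescale_increment)
  # [length, num_timescales] grid of position * inverse timescale.
  scaled_time = position[:, None] * inv_timescales[None, :]
  signal = tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=1)
  # Pad with one zero channel if `channels` is odd, then broadcast over batch.
  signal = tf.pad(signal, [[0, 0], [0, channels % 2]])
  return x + tf.reshape(signal, [1, length, channels])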
Example #6
def transformer_prepare_decoder(targets, hparams):
    """Prepare one shard of the model for the decoder.
  
    Args:
      targets: a Tensor.
      hparams: run hyperparameters
  
    Returns:
      decoder_input: a Tensor, bottom of decoder stack
      decoder_self_attention_bias: a bias tensor for use in encoder self-attention
    """
    decoder_self_attention_bias = (comm_attn.attention_bias_lower_triangle(
        tf.shape(targets)[1]))
    if hparams.proximity_bias:
        decoder_self_attention_bias += comm_attn.attention_bias_proximal(
            tf.shape(targets)[1])
    decoder_input = common_layers.shift_left_3d(targets)
    if hparams.pos == 'timing':
        decoder_input = comm_attn.add_timing_signal_1d(decoder_input)
    # Putting this here since always called immediately after...
    decoder_input = with_dropout(decoder_input, hparams)

    return DecoderState(input=decoder_input,
                        self_attn_bias=decoder_self_attention_bias)
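DecoderState and with_dropout are defined elsewhere in that code base and are not shown in this example. Purely as an illustration of the shape of what this variant returns, a plausible (hypothetical) pair of definitions might look like the following; the real names, fields, and hyperparameter used for dropout are assumptions.

import collections
import tensorflow as tf

# Hypothetical container mirroring the fields used above; the real definition
# in that repository may differ.
DecoderState = collections.namedtuple("DecoderState", ["input", "self_attn_bias"])

def with_dropout(x, hparams):
  # Hypothetical helper: apply input dropout controlled by an hparams field.
  # The field name and the placement of the dropout are assumptions.
  keep_prob = 1.0 - getattr(hparams, "layer_prepostprocess_dropout", 0.0)
  return tf.nn.dropout(x, keep_prob=keep_prob)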