Example #1
def transformer_prepare_encoder(inputs, target_space, hparams):
    """Prepare one shard of the model for the encoder.

  Args:
    inputs: a Tensor.
    target_space: a Tensor.
    hparams: run hyperparameters

  Returns:
    encoder_input: a Tensor, bottom of encoder stack
    encoder_self_attention_bias: a bias tensor for use in encoder self-attention
    encoder_decoder_attention_bias: a bias tensor for use in encoder-decoder
      attention
  """
    ishape_static = inputs.shape.as_list()
    encoder_input = inputs
    encoder_padding = common_attention.embedding_to_padding(encoder_input)
    ignore_padding = common_attention.attention_bias_ignore_padding(
        encoder_padding)
    encoder_self_attention_bias = ignore_padding
    encoder_decoder_attention_bias = ignore_padding
    if hparams.proximity_bias:
        encoder_self_attention_bias += common_attention.attention_bias_proximal(
            tf.shape(inputs)[1])
    # Append target_space_id embedding to inputs.
    emb_target_space = common_layers.embedding(target_space,
                                               32,
                                               ishape_static[-1],
                                               name="target_space_embedding")
    emb_target_space = tf.reshape(emb_target_space, [1, 1, -1])
    encoder_input += emb_target_space
    if hparams.pos == "timing":
        encoder_input = common_attention.add_timing_signal_1d(encoder_input)
    return (encoder_input, encoder_self_attention_bias,
            encoder_decoder_attention_bias)
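A minimal usage sketch (not part of the original example), assuming a tensor2tensor-style hparams set such as transformer.transformer_base() that defines the fields read above (hidden_size, proximity_bias, pos); the input tensors are placeholders:

# Usage sketch: dummy inputs, hparams assumed to come from transformer_base().
import tensorflow as tf
from tensor2tensor.models import transformer

hparams = transformer.transformer_base()
inputs = tf.random_normal([8, 50, hparams.hidden_size])   # [batch, length, hidden]
target_space = tf.constant(3, dtype=tf.int32)             # target-space id (< 32)
encoder_input, self_attn_bias, enc_dec_attn_bias = transformer_prepare_encoder(
    inputs, target_space, hparams)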
Example #2
def attention_lm_moe_prepare_decoder(targets, hparams):
    """Prepare one shard of the model for the decoder.

  Args:
    targets: a Tensor.
    hparams: run hyperparameters

  Returns:
    decoder_input: a Tensor, bottom of decoder stack
    decoder_self_attention_bias: a Tensor, containing large negative values
    to implement masked attention and possibly baises for diagonal alignments
    pad_remover (expert_utils.PadRemover): an util object to remove padding
  """
    targets_pad_mask = common_attention.embedding_to_padding(targets)
    with tf.name_scope("pad_remover"):
        # Because of the shift_right, the <eos> token will be considered
        # padding. In practice it doesn't really matter: due to the triangular
        # mask, this token should never be attended to.
        pad_remover = expert_utils.PadRemover(targets_pad_mask)

    if hparams.prepend_mode == "prepend_inputs_full_attention":
        decoder_self_attention_bias = (
            common_attention.attention_bias_prepended(targets_pad_mask))
    else:
        decoder_self_attention_bias = (
            common_attention.attention_bias_lower_triangle(
                tf.shape(targets)[1]))
    decoder_input = common_layers.shift_right_3d(targets)
    if hparams.pos == "timing":
        decoder_input = common_attention.add_timing_signal_1d(decoder_input)
    return (decoder_input, decoder_self_attention_bias, pad_remover)
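The returned pad_remover might be consumed as sketched below, stripping padded positions before a position-wise layer and restoring them afterwards; targets, hparams and the dense layer are illustrative placeholders reusing the setup from the sketch under Example #1:

# Sketch: drop padded positions before a per-position op, then restore them.
targets = tf.random_normal([8, 40, hparams.hidden_size])
decoder_input, bias, pad_remover = attention_lm_moe_prepare_decoder(targets, hparams)
x = tf.reshape(decoder_input, [-1, hparams.hidden_size])   # [batch * length, hidden]
x = tf.expand_dims(pad_remover.remove(x), axis=0)          # strip padded positions
y = tf.layers.dense(x, hparams.hidden_size)                # any position-wise layer
y = pad_remover.restore(tf.squeeze(y, axis=0))             # zeros back at padding
y = tf.reshape(y, tf.shape(decoder_input))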
Example #3
def prepare_decoder(targets, target_space_emb):
    """Prepare decoder."""
    decoder_self_attention_bias = (
        common_attention.attention_bias_lower_triangle(tf.shape(targets)[1]))
    target_space_emb = tf.reshape(target_space_emb, [1, 1, -1])
    target_space_emb = tf.tile(target_space_emb, [tf.shape(targets)[0], 1, 1])
    decoder_input = common_layers.shift_right_3d(targets,
                                                 pad_value=target_space_emb)
    decoder_input = common_attention.add_timing_signal_1d(decoder_input)
    return (decoder_input, decoder_self_attention_bias)
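A hypothetical call, again reusing the setup from the sketch under Example #1: target_space_emb is built the same way as in Example #1 (a 32-entry embedding of the target-space id) and targets holds already-embedded target tokens:

# Sketch: build a target-space embedding, then prepare the decoder input.
target_space_emb = common_layers.embedding(
    tf.constant(3, dtype=tf.int32), 32, hparams.hidden_size,
    name="target_space_embedding")
targets = tf.random_normal([8, 40, hparams.hidden_size])
decoder_input, decoder_self_attention_bias = prepare_decoder(
    targets, target_space_emb)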
Example #4
def decode(cond_vec, cond_add, gold, c, ed, hparams):
    """Transformer decoder."""
    drop_gold = tf.nn.dropout(gold, 1.0 - hparams.layer_prepostprocess_dropout)
    decoder_input = common_layers.shift_right(drop_gold, pad_value=cond_vec)
    if cond_add is not None:
        decoder_input += cond_add
    decoder_input = tf.squeeze(decoder_input, axis=2)
    decoder_input = common_attention.add_timing_signal_1d(decoder_input)
    bias = common_attention.attention_bias_lower_triangle(tf.shape(gold)[1])
    if c is not None and len(c.get_shape()) > 3:
        c = tf.squeeze(c, axis=2)
    return transformer.transformer_decoder(decoder_input, c, bias, ed, hparams)
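A hypothetical call with shapes inferred from the body above: gold carries target embeddings with a dummy axis 2, cond_vec fills the shifted-in first step, c plays the role of an encoder or conditioning output, and ed its attention bias. Everything beyond the function's own names is illustrative:

# Sketch: dummy 4-D tensors; ed built with no padding in the conditioning.
gold = tf.random_normal([8, 40, 1, hparams.hidden_size])
cond_vec = tf.random_normal([8, 1, 1, hparams.hidden_size])
c = tf.random_normal([8, 50, 1, hparams.hidden_size])
ed = common_attention.attention_bias_ignore_padding(
    tf.zeros([8, 50]))                       # no padded positions in c
decoder_output = decode(cond_vec, None, gold, c, ed, hparams)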
Example #5
def attend(x, source, hparams, name):
    """Multihead-attend x to source, preserving the dummy axis 2."""
    with tf.variable_scope(name):
        x = tf.squeeze(x, axis=2)
        if len(source.get_shape()) > 3:
            source = tf.squeeze(source, axis=2)
        source = common_attention.add_timing_signal_1d(source)
        y = common_attention.multihead_attention(
            common_layers.layer_preprocess(x, hparams), source, None,
            hparams.attention_key_channels or hparams.hidden_size,
            hparams.attention_value_channels or hparams.hidden_size,
            hparams.hidden_size, hparams.num_heads, hparams.attention_dropout)
        res = common_layers.layer_postprocess(x, y, hparams)
        return tf.expand_dims(res, axis=2)
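A usage sketch with hypothetical shapes (same hparams assumption as the earlier sketches): attend expects 4-D tensors with a dummy axis 2 and returns a tensor of the same rank:

# Sketch: attend a 30-step query to a 50-step source.
x = tf.random_normal([8, 30, 1, hparams.hidden_size])
source = tf.random_normal([8, 50, 1, hparams.hidden_size])
out = attend(x, source, hparams, name="attend_to_source")  # [8, 30, 1, hidden]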
Example #6
def transformer_prepare_decoder(targets, hparams):
    """Prepare one shard of the model for the decoder.

  Args:
    targets: a Tensor.
    hparams: run hyperparameters

  Returns:
    decoder_input: a Tensor, bottom of decoder stack
    decoder_self_attention_bias: a bias tensor for use in encoder self-attention
  """
    decoder_self_attention_bias = (
        common_attention.attention_bias_lower_triangle(tf.shape(targets)[1]))
    if hparams.proximity_bias:
        decoder_self_attention_bias += common_attention.attention_bias_proximal(
            tf.shape(targets)[1])
    decoder_input = common_layers.shift_right_3d(targets)
    if hparams.pos == "timing":
        decoder_input = common_attention.add_timing_signal_1d(decoder_input)
    return (decoder_input, decoder_self_attention_bias)
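A sketch of wiring the prepared decoder input into a decoder stack, mirroring the transformer.transformer_decoder call from Example #4; targets, encoder_output and encoder_decoder_attention_bias (as produced in Example #1) are assumed to be available:

# Sketch: feed the prepared decoder input and biases into the decoder stack.
decoder_input, decoder_self_attention_bias = transformer_prepare_decoder(
    targets, hparams)
decoder_output = transformer.transformer_decoder(
    decoder_input, encoder_output, decoder_self_attention_bias,
    encoder_decoder_attention_bias, hparams)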
Example #7
def attention_lm_prepare_decoder(targets, hparams):
    """Prepare one shard of the model for the decoder.

  Args:
    targets: a Tensor.
    hparams: run hyperparameters

  Returns:
    decoder_input: a Tensor, bottom of decoder stack
    decoder_self_attention_bias: a Tensor, containing large negative values
    to implement masked attention and possibly baises for diagonal alignments
  """
    if hparams.prepend_mode == "prepend_inputs_full_attention":
        decoder_self_attention_bias = (
            common_attention.attention_bias_prepended(
                common_attention.embedding_to_padding(targets)))
    else:
        decoder_self_attention_bias = (
            common_attention.attention_bias_lower_triangle(
                tf.shape(targets)[1]))
    decoder_input = common_layers.shift_right_3d(targets)
    if hparams.pos == "timing":
        decoder_input = common_attention.add_timing_signal_1d(decoder_input)
    return (decoder_input, decoder_self_attention_bias)