Example #1
def attention_lm_moe_prepare_decoder(targets, hparams):
    """Prepare one shard of the model for the decoder.

  Args:
    targets: a Tensor.
    hparams: run hyperparameters

  Returns:
    decoder_input: a Tensor, bottom of decoder stack
    decoder_self_attention_bias: a Tensor, containing large negative values
    to implement masked attention and possibly biases for diagonal alignments
    pad_remover (expert_utils.PadRemover): a utility object to remove padding
  """
    targets_pad_mask = common_attention.embedding_to_padding(targets)
    with tf.name_scope("pad_remover"):
        # Because of the shift_right, the <eos> token will be considered as
        # padding. In practice it doesn't really matter: due to the
        # triangular mask, this token should never be attended to.
        pad_remover = expert_utils.PadRemover(targets_pad_mask)

    if hparams.prepend_mode == "prepend_inputs_full_attention":
        decoder_self_attention_bias = (
            common_attention.attention_bias_prepended(targets_pad_mask))
    else:
        decoder_self_attention_bias = (
            common_attention.attention_bias_lower_triangle(
                tf.shape(targets)[1]))
    decoder_input = common_layers.shift_right_3d(targets)
    if hparams.pos == "timing":
        decoder_input = common_attention.add_timing_signal_1d(decoder_input)
    return (decoder_input, decoder_self_attention_bias, pad_remover)
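
Note: every example on this page builds its causal mask with common_attention.attention_bias_lower_triangle. As a rough, standalone illustration of the kind of tensor the docstring above describes (zeros where attention is allowed, large negative values for future positions), here is a minimal sketch; it is an assumption-based reconstruction for clarity, not the tensor2tensor implementation.

# Minimal sketch (assumption, not the library source): a causal bias of shape
# [1, 1, length, length] with 0.0 on and below the diagonal and -1e9 above it,
# so that softmax(logits + bias) zeroes out attention to future positions.
import tensorflow as tf

def causal_bias_sketch(length):
    # Ones on and below the diagonal: position i may attend to positions <= i.
    lower_triangle = tf.matrix_band_part(tf.ones([length, length]), -1, 0)
    bias = -1e9 * (1.0 - lower_triangle)
    # Leading singleton dimensions broadcast over batch and heads.
    return tf.reshape(bias, [1, 1, length, length])
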
Example #2
def prepare_decoder(targets, target_space_emb):
    """Prepare decoder."""
    decoder_self_attention_bias = (
        common_attention.attention_bias_lower_triangle(tf.shape(targets)[1]))
    target_space_emb = tf.reshape(target_space_emb, [1, 1, -1])
    target_space_emb = tf.tile(target_space_emb, [tf.shape(targets)[0], 1, 1])
    decoder_input = common_layers.shift_right_3d(targets,
                                                 pad_value=target_space_emb)
    decoder_input = common_attention.add_timing_signal_1d(decoder_input)
    return (decoder_input, decoder_self_attention_bias)
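
Note: Example #2 reuses the target-space embedding as the content of the first decoder position via the pad_value argument of shift_right_3d (Example #3 below does the same with cond_vec). The helper shift_right_with_start below is a hypothetical stand-in written under the assumption that the pad value is prepended along the time axis and the last target step is dropped; it only illustrates the shape bookkeeping.

# Hypothetical stand-in for shift_right_3d(targets, pad_value=...), assuming
# the pad value is prepended on the time axis and the final step is dropped.
import tensorflow as tf

def shift_right_with_start(targets, start_vec):
    # targets:   [batch, length, hidden]
    # start_vec: [1, 1, hidden], tiled to one start vector per batch element.
    start = tf.tile(start_vec, [tf.shape(targets)[0], 1, 1])
    # Prepend the start vector and drop the last target, so position t of the
    # decoder input holds target t-1 (teacher forcing).
    return tf.concat([start, targets], axis=1)[:, :-1, :]
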
Example #3
def decode(cond_vec, cond_add, gold, c, ed, hparams):
    """Transformer decoder."""
    drop_gold = tf.nn.dropout(gold, 1.0 - hparams.layer_prepostprocess_dropout)
    decoder_input = common_layers.shift_right(drop_gold, pad_value=cond_vec)
    if cond_add is not None:
        decoder_input += cond_add
    decoder_input = tf.squeeze(decoder_input, axis=2)
    decoder_input = common_attention.add_timing_signal_1d(decoder_input)
    bias = common_attention.attention_bias_lower_triangle(tf.shape(gold)[1])
    if c is not None and len(c.get_shape()) > 3:
        c = tf.squeeze(c, axis=2)
    return transformer.transformer_decoder(decoder_input, c, bias, ed, hparams)
Example #4
 def testMultiheadSelfAttentionMemoryEfficient(self):
     num_heads = 4
     io_size = 16
     batch = 2
     length = 7
     head_size = 5
     x = np.random.rand(batch, length, io_size)
     dy = np.random.rand(batch, length, io_size)
     with self.test_session() as session:
         x = tf.to_float(x)
         dy = tf.to_float(dy)
         bias = common_attention.attention_bias_lower_triangle(length)
         wqkv = tf.get_variable(
             "wqkv", [num_heads, 1, io_size, 3 * head_size],
             initializer=tf.random_normal_initializer(stddev=io_size**-0.5))
         wo = tf.get_variable("wo", [num_heads, 1, head_size, io_size],
                              initializer=tf.random_normal_initializer(
                                  stddev=(head_size * num_heads)**-0.5))
         norm_scale, norm_bias = common_layers.layer_norm_vars(io_size)
         y = common_attention.multihead_self_attention_memory_efficient(
             x,
             bias,
             num_heads,
             head_size=head_size,
             forget=False,
             test_vars=(wqkv, wo, norm_scale, norm_bias))
         y_forget = common_attention.multihead_self_attention_memory_efficient(
             x,
             bias,
             num_heads,
             head_size=head_size,
             forget=True,
             test_vars=(wqkv, wo, norm_scale, norm_bias))
         dx, dwqkv, dwo, dnorm_scale, dnorm_bias = tf.gradients(
             ys=[y], xs=[x, wqkv, wo, norm_scale, norm_bias], grad_ys=[dy])
         dx_f, dwqkv_f, dwo_f, dnorm_scale_f, dnorm_bias_f = tf.gradients(
             ys=[y_forget],
             xs=[x, wqkv, wo, norm_scale, norm_bias],
             grad_ys=[dy])
         session.run(tf.global_variables_initializer())
         (y, y_forget, dx, dwqkv, dwo, dnorm_scale, dnorm_bias, dx_f,
          dwqkv_f, dwo_f, dnorm_scale_f, dnorm_bias_f) = session.run([
              y, y_forget, dx, dwqkv, dwo, dnorm_scale, dnorm_bias, dx_f,
              dwqkv_f, dwo_f, dnorm_scale_f, dnorm_bias_f
          ])
     self.assertAllClose(y, y_forget)
     self.assertAllClose(dwo, dwo_f)
     self.assertAllClose(dwqkv, dwqkv_f)
     self.assertAllClose(dnorm_scale, dnorm_scale_f)
     self.assertAllClose(dnorm_bias, dnorm_bias_f)
     self.assertAllClose(dx, dx_f)
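
Note: the test above pushes a fixed upstream gradient dy through tf.gradients via grad_ys and checks that the recomputing variant (forget=True) reproduces both the forward output and every gradient of the standard variant. The same checking pattern on a toy function, as a sketch (the two paths here are trivially identical; only the mechanics of grad_ys are being shown):

# Sketch of the gradient-equivalence pattern used above, on a toy function:
# evaluate the same computation two ways, inject a fixed upstream gradient dy
# via tf.gradients(ys, xs, grad_ys), and compare the results numerically.
import numpy as np
import tensorflow as tf

x = tf.constant(np.random.rand(2, 3), dtype=tf.float32)
dy = tf.constant(np.random.rand(2, 3), dtype=tf.float32)

y_a = tf.tanh(x) * 2.0  # "standard" path
y_b = 2.0 * tf.tanh(x)  # algebraically identical "alternate" path

dx_a = tf.gradients(ys=[y_a], xs=[x], grad_ys=[dy])[0]
dx_b = tf.gradients(ys=[y_b], xs=[x], grad_ys=[dy])[0]

with tf.Session() as session:
    a, b = session.run([dx_a, dx_b])
    np.testing.assert_allclose(a, b, rtol=1e-6)
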
Example #5
def transformer_prepare_decoder(targets, hparams):
    """Prepare one shard of the model for the decoder.

  Args:
    targets: a Tensor.
    hparams: run hyperparameters

  Returns:
    decoder_input: a Tensor, bottom of decoder stack
    decoder_self_attention_bias: a bias tensor for use in decoder self-attention
  """
    decoder_self_attention_bias = (
        common_attention.attention_bias_lower_triangle(tf.shape(targets)[1]))
    if hparams.proximity_bias:
        decoder_self_attention_bias += common_attention.attention_bias_proximal(
            tf.shape(targets)[1])
    decoder_input = common_layers.shift_right_3d(targets)
    if hparams.pos == "timing":
        decoder_input = common_attention.add_timing_signal_1d(decoder_input)
    return (decoder_input, decoder_self_attention_bias)
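
Note: when hparams.proximity_bias is set, Example #5 adds a second bias on top of the causal mask. The sketch below is an illustration only, under the assumption that a proximity bias penalizes attention between positions by their distance |i - j|; it is not the attention_bias_proximal source.

# Illustrative sketch (assumption, not the library source): a proximity-style
# bias that grows more negative with distance, broadcastable to
# [batch, heads, length, length] like the causal bias it is added to.
import tensorflow as tf

def proximity_bias_sketch(length):
    positions = tf.to_float(tf.range(length))
    # distance[i, j] = |i - j|
    distance = tf.abs(tf.expand_dims(positions, 1) - tf.expand_dims(positions, 0))
    # Larger distance -> more negative bias -> smaller attention weight.
    return tf.reshape(-tf.log1p(distance), [1, 1, length, length])
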
Example #6
def attention_lm_prepare_decoder(targets, hparams):
    """Prepare one shard of the model for the decoder.

  Args:
    targets: a Tensor.
    hparams: run hyperparameters

  Returns:
    decoder_input: a Tensor, bottom of decoder stack
    decoder_self_attention_bias: a Tensor, containing large negative values
    to implement masked attention and possibly biases for diagonal alignments
  """
    if hparams.prepend_mode == "prepend_inputs_full_attention":
        decoder_self_attention_bias = (
            common_attention.attention_bias_prepended(
                common_attention.embedding_to_padding(targets)))
    else:
        decoder_self_attention_bias = (
            common_attention.attention_bias_lower_triangle(
                tf.shape(targets)[1]))
    decoder_input = common_layers.shift_right_3d(targets)
    if hparams.pos == "timing":
        decoder_input = common_attention.add_timing_signal_1d(decoder_input)
    return (decoder_input, decoder_self_attention_bias)
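
Note: Examples #1 and #6 derive the padding mask straight from the target embeddings with common_attention.embedding_to_padding. The sketch below reflects an assumption about that convention (a position whose embedding vector is entirely zero counts as padding) rather than the library code.

# Sketch (assumed convention, not the library source): a position is treated
# as padding when its embedding vector is all zeros, giving a float mask of
# shape [batch, length] with 1.0 at padded positions.
import tensorflow as tf

def embedding_to_padding_sketch(emb):
    # emb: [batch, length, hidden]
    emb_sum = tf.reduce_sum(tf.abs(emb), axis=-1)
    return tf.to_float(tf.equal(emb_sum, 0.0))
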
Example #7
    def _greedy_infer(self, features, decode_length, last_position_only=True):
        """Fast version of greedy decoding.

    Args:
      features: a map of string to `Tensor`
      decode_length: an integer.  How many additional timesteps to decode.
      last_position_only: MUST be true for fast decoding!

    Returns:
       samples: [batch_size, input_length + decode_length]
       logits: Not returned
       losses: Not returned

    Raises:
      ValueError: If last_position_only is False
      NotImplementedError: If there are multiple data shards.
    """
        if not last_position_only:
            raise ValueError(
                "Fast decoding only deals with the last positions!")
        if self._num_datashards != 1:
            raise NotImplementedError(
                "Fast decoding only supports a single shard.")
        dp = self._data_parallelism
        hparams = self._hparams

        inputs = features["inputs"]
        batch_size = tf.shape(inputs)[0]
        target_modality = self._problem_hparams.target_modality
        if t2t_model.is_class_modality(target_modality):
            decode_length = 1
        else:
            decode_length = tf.shape(inputs)[1] + decode_length

        # TODO(llion): Clean up this reshaping logic.
        inputs = tf.expand_dims(inputs, axis=1)
        if len(inputs.shape) < 5:
            inputs = tf.expand_dims(inputs, axis=4)
        s = tf.shape(inputs)
        inputs = tf.reshape(inputs, [s[0] * s[1], s[2], s[3], s[4]])
        # _shard_features called to ensure that the variable names match
        inputs = self._shard_features({"inputs": inputs})["inputs"]
        input_modality = self._problem_hparams.input_modality["inputs"]
        with tf.variable_scope(input_modality.name):
            inputs = input_modality.bottom_sharded(inputs, dp)
        with tf.variable_scope("body"):
            encoder_output, encoder_decoder_attention_bias = dp(
                self.encode, inputs, features["target_space_id"], hparams)

        if hparams.pos == "timing":
            timing_signal = common_attention.get_timing_signal_1d(
                decode_length + 1, hparams.hidden_size)

        def preprocess_targets(targets, i):
            """Performs preprocessing steps on the targets to prepare for the decoder.

      This includes:
        - Embedding the ids.
        - Flattening to 3D tensor.
        - Optionally adding timing signals.

      Args:
        targets: input ids to the decoder. [batch_size, 1]
        i: scalar, step number of the decoding loop.

      Returns:
        Processed targets [batch_size, 1, hidden_dim]
      """
            # _shard_features called to ensure that the variable names match
            targets = self._shard_features({"targets": targets})["targets"]
            with tf.variable_scope(target_modality.name):
                targets = target_modality.targets_bottom_sharded(targets,
                                                                 dp)[0]
            targets = common_layers.flatten4d3d(targets)

            # TODO(llion): Explain! Is this even needed?
            targets = tf.cond(tf.equal(i, 0), lambda: tf.zeros_like(targets),
                              lambda: targets)

            if hparams.pos == "timing":
                targets += timing_signal[:, i:i + 1]
            return targets

        decoder_self_attention_bias = (
            common_attention.attention_bias_lower_triangle(decode_length))
        if hparams.proximity_bias:
            decoder_self_attention_bias += common_attention.attention_bias_proximal(
                decode_length)

        def symbols_to_logits_fn(ids, i, cache):
            """Go from ids to logits for next symbol."""
            targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
            targets = preprocess_targets(targets, i)

            bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]

            with tf.variable_scope("body"):
                body_outputs = dp(self.decode, targets, encoder_output[0],
                                  encoder_decoder_attention_bias[0], bias,
                                  hparams, cache)

            with tf.variable_scope(target_modality.name):
                logits = target_modality.top_sharded(body_outputs, None, dp)[0]

            return tf.squeeze(logits, axis=[1, 2, 3])

        def inner_loop(i, next_id, decoded_ids, cache):
            logits = symbols_to_logits_fn(next_id, i, cache)
            next_id = tf.expand_dims(tf.argmax(logits, axis=-1), axis=1)
            decoded_ids = tf.concat([decoded_ids, next_id], axis=1)
            return i + 1, next_id, decoded_ids, cache

        key_channels = hparams.attention_key_channels or hparams.hidden_size
        value_channels = hparams.attention_value_channels or hparams.hidden_size
        num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers

        cache = {
            "layer_%d" % layer: {
                "k": tf.zeros([batch_size, 0, key_channels]),
                "v": tf.zeros([batch_size, 0, value_channels]),
            }
            for layer in range(num_layers)
        }
        decoded_ids = tf.zeros([batch_size, 0], dtype=tf.int64)
        next_id = tf.zeros([batch_size, 1], dtype=tf.int64)
        _, _, decoded_ids, _ = tf.while_loop(
            # TODO(llion): Early stopping.
            lambda i, *_: tf.less(i, decode_length),
            inner_loop,
            [tf.constant(0), next_id, decoded_ids, cache],
            shape_invariants=[
                tf.TensorShape([]),
                tf.TensorShape([None, None]),
                tf.TensorShape([None, None]),
                {
                    "layer_%d" % layer: {
                        "k": tf.TensorShape([None, None, key_channels]),
                        "v": tf.TensorShape([None, None, value_channels]),
                    }
                    for layer in range(num_layers)
                }
            ])

        return decoded_ids, None, None
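
Note: inside symbols_to_logits_fn, the full [1, 1, decode_length, decode_length] causal bias is sliced to the single row needed at step i: bias[:, :, i:i + 1, :i + 1], so the new query position attends to positions 0..i only. A small standalone sketch of that slice (decode_length and i are arbitrary illustrative values):

# Sketch of the step-i bias slice used above: at decoding step i, the single
# query row may attend to key positions 0..i only, so the slice is all zeros.
import tensorflow as tf

decode_length = 5
lower = tf.matrix_band_part(tf.ones([decode_length, decode_length]), -1, 0)
full_bias = tf.reshape(-1e9 * (1.0 - lower),
                       [1, 1, decode_length, decode_length])

i = 2
step_bias = full_bias[:, :, i:i + 1, :i + 1]  # shape [1, 1, 1, i + 1]

with tf.Session() as session:
    print(session.run(step_bias))  # [[[[0. 0. 0.]]]]
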