Example #1
  def _fast_decode(self,
                   features,
                   decode_length,
                   beam_size=1,
                   top_beams=1,
                   alpha=1.0,
                   sampling_method='random'):
    """Fast decoding.

    Implements both greedy and beam search decoding; beam search is used iff
    beam_size > 1, otherwise the beam-search-related arguments are ignored.

    Args:
      features: a map of string to model features.
      decode_length: an integer. How many additional timesteps to decode.
      beam_size: number of beams.
      top_beams: an integer. How many of the beams to return.
      alpha: Float that controls the length penalty. The larger the alpha, the
        stronger the preference for longer translations.
      sampling_method: "argmax" for greedy decoding; any other value samples
        with temperature hparams.sampling_temp. Only used when beam_size == 1.

    Returns:
      samples: an integer `Tensor`. Top samples from the beam search.

    Raises:
      NotImplementedError: If there are multiple data shards.
    """
    if self._num_datashards != 1:
      raise NotImplementedError("Fast decoding only supports a single shard.")
    dp = self._data_parallelism
    hparams = self._hparams

    inputs = features["inputs"]
    batch_size = tf.shape(inputs)[0]
    target_modality = self._problem_hparams.target_modality
    if t2t_model.is_class_modality(target_modality):
      decode_length = 1
    else:
      decode_length = tf.shape(inputs)[1] + decode_length

    # TODO(llion): Clean up this reshaping logic.
    inputs = tf.expand_dims(inputs, axis=1)
    if len(inputs.shape) < 5:
      inputs = tf.expand_dims(inputs, axis=4)
    s = tf.shape(inputs)
    inputs = tf.reshape(inputs, [s[0] * s[1], s[2], s[3], s[4]])
    # _shard_features called to ensure that the variable names match
    inputs = self._shard_features({"inputs": inputs})["inputs"]
    input_modality = self._problem_hparams.input_modality["inputs"]
    with tf.variable_scope(input_modality.name, reuse=tf.AUTO_REUSE):
      inputs = input_modality.bottom_sharded(inputs, dp)  # embed the input ids
    with tf.variable_scope("body", reuse=tf.AUTO_REUSE):
      encoder_output, encoder_decoder_attention_bias = dp(
          self.encode, inputs, features["target_space_id"], hparams)  # run the encoder
    encoder_output = encoder_output[0]
    encoder_decoder_attention_bias = encoder_decoder_attention_bias[0]

    if hparams.pos == "timing":
      timing_signal = common_attention.get_timing_signal_1d(
          decode_length + 1, hparams.hidden_size)
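      # timing_signal has shape [1, decode_length + 1, hidden_size]; one
      # position of it is added per step in preprocess_targets below.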

    def preprocess_targets(targets, i):
      """Performs preprocessing steps on the targets to prepare for the decoder.

      This includes:
        - Embedding the ids.
        - Flattening to 3D tensor.
        - Optionally adding timing signals.

      Args:
        targets: input ids to the decoder. [batch_size, 1]
        i: scalar, Step number of the decoding loop.

      Returns:
        Processed targets [batch_size, 1, hidden_dim]
      """
      # _shard_features called to ensure that the variable names match
      targets = self._shard_features({"targets": targets})["targets"]
      with tf.variable_scope(target_modality.name, reuse=tf.AUTO_REUSE):
        targets = target_modality.targets_bottom_sharded(targets, dp)[0]
      targets = common_layers.flatten4d3d(targets)

      # TODO(llion): Explain! Is this even needed?
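      # (Best guess: step 0 feeds a dummy zero id, so its embedding is zeroed
      # out to act as an explicit start-of-sequence input.)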
      targets = tf.cond(
          tf.equal(i, 0), lambda: tf.zeros_like(targets), lambda: targets)

      if hparams.pos == "timing":
        targets += timing_signal[:, i:i + 1]
      return targets

    decoder_self_attention_bias = (
        common_attention.attention_bias_lower_triangle(decode_length))
    if hparams.proximity_bias:
      decoder_self_attention_bias += common_attention.attention_bias_proximal(
          decode_length)
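    # The bias is a lower-triangular mask over the full decode_length; each
    # step below slices out row i, so a position can only attend to <= i.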

    def symbols_to_logits_fn(ids, i, cache):
      """Go from ids to logits for next symbol."""
      ids = ids[:, -1:]
      targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
      targets = preprocess_targets(targets, i)

      bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]

      with tf.variable_scope("body", reuse=tf.AUTO_REUSE):
        body_outputs = dp(self.decode, targets, cache["encoder_output"],
                          cache["encoder_decoder_attention_bias"], bias,
                          hparams, cache)

      with tf.variable_scope(target_modality.name, reuse=tf.AUTO_REUSE):
        logits = target_modality.top_sharded(body_outputs, None, dp)[0]

      return tf.squeeze(logits, axis=[1, 2, 3]), cache

    key_channels = hparams.attention_key_channels or hparams.hidden_size
    value_channels = hparams.attention_value_channels or hparams.hidden_size
    num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers

    cache = {
        "layer_%d" % layer: {
            "k": tf.zeros([batch_size, 0, key_channels]),
            "v": tf.zeros([batch_size, 0, value_channels]),
        }
        for layer in range(num_layers)
    }
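    # Each layer's key/value cache starts with a zero-length time axis and
    # grows by one position per decode step (see the sketch after this
    # example), so past positions are never re-encoded.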

    # Set 2nd dim to None since it's not invariant in the tf.while_loop
    # Note: Tensor.set_shape() does not work here since it merges shape info.
    # TODO(llion): Find a more robust solution.
    # pylint: disable=protected-access
    for layer in cache:
      cache[layer]["k"]._shape = tf.TensorShape([None, None, key_channels])
      cache[layer]["v"]._shape = tf.TensorShape([None, None, value_channels])
    # pylint: enable=protected-access
    cache["encoder_output"] = encoder_output
    cache["encoder_decoder_attention_bias"] = encoder_decoder_attention_bias

    if beam_size > 1:  # Beam Search
      target_modality = (
          self._hparams.problems[self._problem_idx].target_modality)
      vocab_size = target_modality.top_dimensionality
      initial_ids = tf.zeros([batch_size], dtype=tf.int32)
      decoded_ids, scores = beam_search.beam_search(
          symbols_to_logits_fn, initial_ids, beam_size, decode_length,
          vocab_size, alpha, states=cache, stop_early=(top_beams == 1))
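      # decoded_ids includes the initial dummy id at position 0 (stripped
      # below); in tensor2tensor's beam_search, scores are log-probabilities
      # normalized by the GNMT length penalty ((5 + length) / 6) ** alpha.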

      if top_beams == 1:
        decoded_ids = decoded_ids[:, 0, 1:]
      else:
        decoded_ids = decoded_ids[:, :top_beams, 1:]
    else:  # Greedy

      def inner_loop(i, next_id, decoded_ids, cache):
        logits, cache = symbols_to_logits_fn(next_id, i, cache)
        temperature = (0.0 if sampling_method == "argmax"
                       else hparams.sampling_temp)
        next_id = tf.expand_dims(   # sample with argmax or multinomial
            common_layers.sample_with_temperature(logits, temperature), axis=1)
        decoded_ids = tf.concat([decoded_ids, next_id], axis=1)
        return i + 1, next_id, decoded_ids, cache

      decoded_ids = tf.zeros([batch_size, 0], dtype=tf.int64)
      scores = None
      next_id = tf.zeros([batch_size, 1], dtype=tf.int64)
      _, _, decoded_ids, _ = tf.while_loop(
          # TODO(llion): Early stopping.
          lambda i, *_: tf.less(i, decode_length),
          inner_loop,
          [tf.constant(0), next_id, decoded_ids, cache],
          shape_invariants=[
              tf.TensorShape([]),
              tf.TensorShape([None, None]),
              tf.TensorShape([None, None]),
              nest.map_structure(lambda t: tf.TensorShape(t.shape), cache),
          ])

    return decoded_ids, scores
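
A minimal, self-contained sketch of the growing-tensor decode loop used in the
greedy branch above, with a dummy logits function standing in for the model
(the names and the toy logits are illustrative, not from tensor2tensor):

import tensorflow as tf

batch_size, vocab_size, decode_length = 2, 11, 5

def dummy_symbols_to_logits_fn(next_id, i):
  # Toy logits that always favor token (i + 1) at step i.
  return tf.one_hot(tf.fill([batch_size], i + 1), vocab_size)

def inner_loop(i, next_id, decoded_ids):
  logits = dummy_symbols_to_logits_fn(next_id, i)
  next_id = tf.expand_dims(tf.argmax(logits, axis=-1), axis=1)
  decoded_ids = tf.concat([decoded_ids, next_id], axis=1)  # grow time axis
  return i + 1, next_id, decoded_ids

decoded_ids = tf.zeros([batch_size, 0], dtype=tf.int64)
next_id = tf.zeros([batch_size, 1], dtype=tf.int64)
_, _, decoded_ids = tf.while_loop(
    lambda i, *_: tf.less(i, decode_length),
    inner_loop,
    [tf.constant(0), next_id, decoded_ids],
    # The time dimension grows every iteration, so it must be None here;
    # this is exactly why the examples on this page pass shape_invariants.
    shape_invariants=[
        tf.TensorShape([]),
        tf.TensorShape([None, None]),
        tf.TensorShape([None, None]),
    ])
print(decoded_ids)  # [[1 2 3 4 5] [1 2 3 4 5]]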
Example #2
    def _greedy_infer(self, features, decode_length, last_position_only=True):
        """Fast version of greedy decoding.

    Args:
      features: an map of string to `Tensor`
      decode_length: an integer.  How many additional timesteps to decode.
      last_position_only: MUST be true for fast decoding!

    Returns:
       samples: [batch_size, input_length + decode_length]
       logits: Not returned
       losses: Not returned

    Raises:
      ValueError: If last_position_only if False
      NotImplementedError: If there are multiple data shards.
    """
        if not last_position_only:
            raise ValueError(
                "Fast decoding only deals with the last positions!")
        if self._num_datashards != 1:
            raise NotImplementedError(
                "Fast decoding only supports a single shard.")
        dp = self._data_parallelism
        hparams = self._hparams

        inputs = features["inputs"]
        batch_size = tf.shape(inputs)[0]
        target_modality = self._problem_hparams.target_modality
        if t2t_model.is_class_modality(target_modality):
            decode_length = 1
        else:
            decode_length = tf.shape(inputs)[1] + decode_length

        # TODO(llion): Clean up this reshaping logic.
        inputs = tf.expand_dims(inputs, axis=1)
        if len(inputs.shape) < 5:
            inputs = tf.expand_dims(inputs, axis=4)
        s = tf.shape(inputs)
        inputs = tf.reshape(inputs, [s[0] * s[1], s[2], s[3], s[4]])
        # _shard_features called to ensure that the variable names match
        inputs = self._shard_features({"inputs": inputs})["inputs"]
        input_modality = self._problem_hparams.input_modality["inputs"]
        with tf.variable_scope(input_modality.name):
            inputs = input_modality.bottom_sharded(inputs, dp)
        with tf.variable_scope("body"):
            encoder_output, encoder_decoder_attention_bias = dp(
                self.encode, inputs, features["target_space_id"], hparams)

        if hparams.pos == "timing":
            timing_signal = common_attention.get_timing_signal_1d(
                decode_length + 1, hparams.hidden_size)

        def preprocess_targets(targets, i):
            """Performs preprocessing steps on the targets to prepare for the decoder.

      This includes:
        - Embedding the ids.
        - Flattening to 3D tensor.
        - Optionally adding timing signals.

      Args:
        targets: inputs ids to the decoder. [batch_size, 1]
        i: scalar, Step number of the decoding loop.

      Returns:
        Processed targets [batch_size, 1, hidden_dim]
      """
            # _shard_features called to ensure that the variable names match
            targets = self._shard_features({"targets": targets})["targets"]
            with tf.variable_scope(target_modality.name):
                targets = target_modality.targets_bottom_sharded(targets,
                                                                 dp)[0]
            targets = common_layers.flatten4d3d(targets)

            # TODO(llion): Explain! Is this even needed?
            targets = tf.cond(tf.equal(i, 0), lambda: tf.zeros_like(targets),
                              lambda: targets)

            if hparams.pos == "timing":
                targets += timing_signal[:, i:i + 1]
            return targets

        decoder_self_attention_bias = (
            common_attention.attention_bias_lower_triangle(decode_length))
        if hparams.proximity_bias:
            decoder_self_attention_bias += common_attention.attention_bias_proximal(
                decode_length)

        def symbols_to_logits_fn(ids, i, cache):
            """Go from ids to logits for next symbol."""
            targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
            targets = preprocess_targets(targets, i)

            bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]

            with tf.variable_scope("body"):
                body_outputs = dp(self.decode, targets, encoder_output[0],
                                  encoder_decoder_attention_bias[0], bias,
                                  hparams, cache)

            with tf.variable_scope(target_modality.name):
                logits = target_modality.top_sharded(body_outputs, None, dp)[0]

            return tf.squeeze(logits, axis=[1, 2, 3])

        def inner_loop(i, next_id, decoded_ids, cache):
            logits = symbols_to_logits_fn(next_id, i, cache)
            next_id = tf.expand_dims(tf.argmax(logits, axis=-1), axis=1)
            decoded_ids = tf.concat([decoded_ids, next_id], axis=1)
            return i + 1, next_id, decoded_ids, cache

        key_channels = hparams.attention_key_channels or hparams.hidden_size
        value_channels = hparams.attention_value_channels or hparams.hidden_size
        num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers

        cache = {
            "layer_%d" % layer: {
                "k": tf.zeros([batch_size, 0, key_channels]),
                "v": tf.zeros([batch_size, 0, value_channels]),
            }
            for layer in range(num_layers)
        }
        decoded_ids = tf.zeros([batch_size, 0], dtype=tf.int64)
        next_id = tf.zeros([batch_size, 1], dtype=tf.int64)
        _, _, decoded_ids, _ = tf.while_loop(
            # TODO(llion): Early stopping.
            lambda i, *_: tf.less(i, decode_length),
            inner_loop,
            [tf.constant(0), next_id, decoded_ids, cache],
            shape_invariants=[
                tf.TensorShape([]),
                tf.TensorShape([None, None]),
                tf.TensorShape([None, None]),
                {
                    "layer_%d" % layer: {
                        "k": tf.TensorShape([None, None, key_channels]),
                        "v": tf.TensorShape([None, None, value_channels]),
                    }
                    for layer in range(num_layers)
                }
            ])

        return decoded_ids, None, None
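
Both examples rely on the same per-layer key/value cache pattern: each layer's
"k" and "v" start with a zero-length time axis and get one position appended
per decode step, so earlier positions are never recomputed. A minimal sketch
of just that append (schematic, not tensor2tensor's attention code):

import tensorflow as tf

batch_size, key_channels = 2, 8
cache_k = tf.zeros([batch_size, 0, key_channels])

for step in range(3):
  new_k = tf.random.normal([batch_size, 1, key_channels])  # this step's key
  cache_k = tf.concat([cache_k, new_k], axis=1)            # append in time
  print(step, cache_k.shape)  # (2, 1, 8), then (2, 2, 8), then (2, 3, 8)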