Example #1
def transformer_prepare_decoder(targets, hparams, features=None):
  """Prepare one shard of the model for the decoder.

  Args:
    targets: a Tensor.
    hparams: run hyperparameters
    features: optionally pass the entire features dictionary as well.
      This is needed now for "packed" datasets.

  Returns:
    decoder_input: a Tensor, bottom of decoder stack
    decoder_self_attention_bias: a bias tensor for use in decoder self-attention
  """
  decoder_self_attention_bias = (
      common_attention.attention_bias_lower_triangle(
          common_layers.shape_list(targets)[1]))
  if features and "targets_segmentation" in features:
    # "Packed" dataset - keep the examples from seeing each other.
    targets_segmentation = features["targets_segmentation"]
    targets_position = features["targets_position"]
    decoder_self_attention_bias += common_attention.attention_bias_same_segment(
        targets_segmentation, targets_segmentation)
  else:
    targets_position = None
  if hparams.proximity_bias:
    decoder_self_attention_bias += common_attention.attention_bias_proximal(
        common_layers.shape_list(targets)[1])
  decoder_input = common_layers.shift_right_3d(targets)
  if hparams.pos == "timing":
    if targets_position is not None:
      decoder_input = common_attention.add_timing_signal_1d_given_position(
          decoder_input, targets_position)
    else:
      decoder_input = common_attention.add_timing_signal_1d(decoder_input)
  return (decoder_input, decoder_self_attention_bias)
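For orientation, a minimal sketch of how these two return values are typically consumed, assuming the same tensor2tensor-style imports as above plus a transformer_decoder stack function with the usual (decoder_input, encoder_output, decoder_self_attention_bias, encoder_decoder_attention_bias, hparams) signature; the wrapper name is illustrative:
def run_decoder_sketch(targets, encoder_output,
                       encoder_decoder_attention_bias, hparams, features=None):
  # `targets` is a [batch, length, hidden] tensor of target embeddings.
  decoder_input, decoder_self_attention_bias = transformer_prepare_decoder(
      targets, hparams, features=features)
  # The shifted inputs and the causal bias feed straight into the decoder stack.
  return transformer_decoder(decoder_input, encoder_output,
                             decoder_self_attention_bias,
                             encoder_decoder_attention_bias, hparams)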
def transformer_prepare_encoder(inputs, target_space, hparams, features=None):
    """Prepare one shard of the model for the encoder.

  Args:
    inputs: a Tensor.
    target_space: a Tensor.
    hparams: run hyperparameters
    features: optionally pass the entire features dictionary as well.
      This is needed now for "packed" datasets.

  Returns:
    encoder_input: a Tensor, bottom of encoder stack
    encoder_self_attention_bias: a bias tensor for use in encoder self-attention
    encoder_decoder_attention_bias: a bias tensor for use in encoder-decoder
      attention
  """
    ishape_static = inputs.shape.as_list()
    encoder_input = inputs
    if features and "inputs_segmentation" in features:
        # Packed dataset.  Keep the examples from seeing each other.
        inputs_segmentation = features["inputs_segmentation"]
        inputs_position = features["inputs_position"]
        targets_segmentation = features["targets_segmentation"]
        encoder_self_attention_bias = common_attention.attention_bias_same_segment(
            inputs_segmentation, inputs_segmentation)
        encoder_decoder_attention_bias = (
            common_attention.attention_bias_same_segment(
                targets_segmentation, inputs_segmentation))
    else:
        # Usual case - not a packed dataset.
        encoder_padding = common_attention.embedding_to_padding(encoder_input)
        ignore_padding = common_attention.attention_bias_ignore_padding(
            encoder_padding)
        encoder_self_attention_bias = ignore_padding
        encoder_decoder_attention_bias = ignore_padding
        inputs_position = None
    if hparams.proximity_bias:
        encoder_self_attention_bias += common_attention.attention_bias_proximal(
            common_layers.shape_list(inputs)[1])
    # Append target_space_id embedding to inputs.
    emb_target_space = common_layers.embedding(target_space,
                                               32,
                                               ishape_static[-1],
                                               name="target_space_embedding")
    emb_target_space = tf.reshape(emb_target_space, [1, 1, -1])
    encoder_input += emb_target_space
    if hparams.pos == "timing":
        if inputs_position is not None:
            encoder_input = common_attention.add_timing_signal_1d_given_position(
                encoder_input, inputs_position)
        else:
            encoder_input = common_attention.add_timing_signal_1d(
                encoder_input)
    return (encoder_input, encoder_self_attention_bias,
            encoder_decoder_attention_bias)
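Likewise, a minimal sketch of how the encoder preparation is typically wired into an encoder stack; transformer_encoder and the dropout call are assumptions based on the usual tensor2tensor Transformer.encode flow, and the wrapper name is illustrative:
def run_encoder_sketch(inputs, target_space, hparams, features=None):
  # `inputs` is a [batch, length, hidden] tensor of input embeddings.
  encoder_input, self_attention_bias, encoder_decoder_attention_bias = (
      transformer_prepare_encoder(inputs, target_space, hparams,
                                  features=features))
  # Input dropout is usually applied right after preparation.
  encoder_input = tf.nn.dropout(encoder_input,
                                1.0 - hparams.layer_prepostprocess_dropout)
  encoder_output = transformer_encoder(encoder_input, self_attention_bias,
                                       hparams)
  return encoder_output, encoder_decoder_attention_bias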
Example #3
def transformer_prepare_decoder(targets, hparams):
    """Copied from tensor2tensor.models.transformer."""
    decoder_self_attention_bias = (
        common_attention.attention_bias_lower_triangle(tf.shape(targets)[1]))
    if hparams.proximity_bias:
        decoder_self_attention_bias += common_attention.attention_bias_proximal(
            tf.shape(targets)[1])
    decoder_input = common_layers.shift_right_3d(targets)
    if hparams.pos == "timing":
        decoder_input = common_attention.add_timing_signal_1d(decoder_input)
    return (decoder_input, decoder_self_attention_bias)
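To make the self-attention bias concrete, here is a small self-contained NumPy stand-in for what attention_bias_lower_triangle produces: zeros where a position may attend (itself and earlier positions) and a large negative value elsewhere, which is added to the attention logits before the softmax. This mirrors the shape and semantics only, not the library implementation.
import numpy as np

def lower_triangle_bias_sketch(length, neg_inf=-1e9):
  # Row i allows keys 0..i and blocks future positions with a -1e9 penalty.
  allowed = np.tril(np.ones((length, length), dtype=np.float32))
  bias = (1.0 - allowed) * neg_inf
  return bias.reshape(1, 1, length, length)  # [1, 1, length, length]

print(lower_triangle_bias_sketch(3)[0, 0])
# -> roughly [[0, -1e9, -1e9], [0, 0, -1e9], [0, 0, 0]]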
Example #4
def transformer_prepare_decoder(targets, hparams, features=None):
    """Prepare one shard of the model for the decoder.

  Args:
    targets: a Tensor.
    hparams: run hyperparameters
    features: optionally pass the entire features dictionary as well.
      This is needed now for "packed" datasets.

  Returns:
    decoder_input: a Tensor, bottom of decoder stack
    decoder_self_attention_bias: a bias tensor for use in decoder self-attention
  """
    decoder_self_attention_bias = (
        common_attention.attention_bias_lower_triangle(
            common_layers.shape_list(targets)[1]))
    if features and "targets_segmentation" in features:
        # "Packed" dataset - keep the examples from seeing each other.
        targets_segmentation = features["targets_segmentation"]
        targets_position = features["targets_position"]
        decoder_self_attention_bias += common_attention.attention_bias_same_segment(
            targets_segmentation, targets_segmentation)
    else:
        targets_position = None
    if hparams.proximity_bias:
        decoder_self_attention_bias += common_attention.attention_bias_proximal(
            common_layers.shape_list(targets)[1])
    decoder_input = common_layers.shift_right_3d(targets)
    #if hparams.pos == "timing":
    #  if targets_position is not None:
    #    decoder_input = common_attention.add_timing_signal_1d_given_position(
    #        decoder_input, targets_position)
    #  else:
    #    decoder_input = common_attention.add_timing_signal_1d(decoder_input)
    raw_decoder_input = common_layers.shift_right(features['targets_raw'])
    terminal_decoder_bias, nonterminal_decoder_bias = _get_t_nt_bias(
        raw_decoder_input, hparams, decoder_self_attention_bias)
    raw_decoder_input = tf.squeeze(raw_decoder_input, axis=[-2, -1])
    pos_signals = generate_positional_signals(raw_decoder_input, hparams,
                                              terminal_decoder_bias,
                                              nonterminal_decoder_bias)
    pos_embeddings = generate_positional_embeddings(pos_signals,
                                                    hparams.decoder_pos,
                                                    hparams)
    if "sum" in hparams.decoder_pos_integration:
        decoder_input = decoder_input + pos_embeddings
    elif "ffn" in hparams.decoder_pos_integration:
        with tf.variable_scope("decoder_pos_ffn"):
            decoder_input = tf.concat([decoder_input, pos_embeddings], axis=2)
            decoder_input = transformer_ffn_layer(decoder_input,
                                                  hparams,
                                                  conv_padding="LEFT")
    return (decoder_input, decoder_self_attention_bias, terminal_decoder_bias,
            nonterminal_decoder_bias, pos_signals)
Example #5
def transformer_prepare_encoder(inputs, target_space, hparams, features=None):
  """Prepare one shard of the model for the encoder.

  Args:
    inputs: a Tensor.
    target_space: a Tensor.
    hparams: run hyperparameters
    features: optionally pass the entire features dictionary as well.
      This is needed now for "packed" datasets.

  Returns:
    encoder_input: a Tensor, bottom of encoder stack
    encoder_self_attention_bias: a bias tensor for use in encoder self-attention
    encoder_decoder_attention_bias: a bias tensor for use in encoder-decoder
      attention
  """
  ishape_static = inputs.shape.as_list()
  encoder_input = inputs
  if features and "inputs_segmentation" in features:
    # Packed dataset.  Keep the examples from seeing each other.
    inputs_segmentation = features["inputs_segmentation"]
    inputs_position = features["inputs_position"]
    targets_segmentation = features["targets_segmentation"]
    encoder_self_attention_bias = common_attention.attention_bias_same_segment(
        inputs_segmentation, inputs_segmentation)
    encoder_decoder_attention_bias = (
        common_attention.attention_bias_same_segment(
            targets_segmentation, inputs_segmentation))
  else:
    # Usual case - not a packed dataset.
    encoder_padding = common_attention.embedding_to_padding(encoder_input)
    ignore_padding = common_attention.attention_bias_ignore_padding(
        encoder_padding)
    encoder_self_attention_bias = ignore_padding
    encoder_decoder_attention_bias = ignore_padding
    inputs_position = None
  if hparams.proximity_bias:
    encoder_self_attention_bias += common_attention.attention_bias_proximal(
        common_layers.shape_list(inputs)[1])
  # Append target_space_id embedding to inputs.
  emb_target_space = common_layers.embedding(
      target_space, 32, ishape_static[-1], name="target_space_embedding")
  emb_target_space = tf.reshape(emb_target_space, [1, 1, -1])
  encoder_input += emb_target_space
  if hparams.pos == "timing":
    if inputs_position is not None:
      encoder_input = common_attention.add_timing_signal_1d_given_position(
          encoder_input, inputs_position)
    else:
      encoder_input = common_attention.add_timing_signal_1d(encoder_input)
  return (encoder_input, encoder_self_attention_bias,
          encoder_decoder_attention_bias)
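In the non-packed branch above, the padding bias plays the same role for padded positions; a self-contained NumPy sketch of what embedding_to_padding followed by attention_bias_ignore_padding amounts to (illustrative, not the library code):
import numpy as np

def ignore_padding_bias_sketch(embeddings, neg_inf=-1e9):
  # embeddings: [batch, length, depth]. A position counts as padding when its
  # embedding is all zeros; attention to it is suppressed with -1e9.
  is_padding = (np.abs(embeddings).sum(axis=-1) == 0.0).astype(np.float32)
  return (is_padding * neg_inf).reshape(embeddings.shape[0], 1, 1, -1)

emb = np.zeros((1, 4, 8), dtype=np.float32)
emb[0, :2] = 1.0  # first two positions are real tokens, last two are padding
print(ignore_padding_bias_sketch(emb)[0, 0, 0])  # -> [0, 0, -1e9, -1e9]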
Example #6
def transformer_prepare_decoder(targets, hparams, features=None):
    """Prepare one shard of the model for the decoder.

  Args:
    targets: a Tensor.
    hparams: run hyperparameters
    features: optionally pass the entire features dictionary as well.
      This is needed now for "packed" datasets.

  Returns:
    decoder_input: a Tensor, bottom of decoder stack
    decoder_self_attention_bias: a bias tensor for use in decoder self-attention
  """
    if hparams.prepend_mode == "prepend_inputs_full_attention":
        decoder_self_attention_bias = (
            common_attention.attention_bias_prepend_inputs_full_attention(
                common_attention.embedding_to_padding(targets)))
    else:
        decoder_self_attention_bias = (
            common_attention.attention_bias_lower_triangle(
                common_layers.shape_list(targets)[1]))

    if features and "targets_segmentation" in features:
        # "Packed" dataset - keep the examples from seeing each other.
        targets_segmentation = features["targets_segmentation"]
        targets_position = features["targets_position"]
        decoder_self_attention_bias += common_attention.attention_bias_same_segment(
            targets_segmentation, targets_segmentation)
    else:
        targets_position = None
    if hparams.proximity_bias:
        decoder_self_attention_bias += common_attention.attention_bias_proximal(
            common_layers.shape_list(targets)[1])
    decoder_input = common_layers.shift_right_3d(targets)
    if hparams.pos == "timing":
        if targets_position is not None:
            decoder_input = common_attention.add_timing_signal_1d_given_position(
                decoder_input, targets_position)
        else:
            decoder_input = common_attention.add_timing_signal_1d(
                decoder_input)
    if hparams.activation_dtype == "bfloat16":
        decoder_self_attention_bias = tf.cast(decoder_self_attention_bias,
                                              tf.bfloat16)
    return (decoder_input, decoder_self_attention_bias)
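The packed-dataset branch above relies on attention_bias_same_segment, which conceptually adds -1e9 wherever the query and key positions come from different packed examples. A small NumPy sketch of that idea (illustrative only; the shape follows the usual [batch, 1, query_length, memory_length] bias convention):
import numpy as np

def same_segment_bias_sketch(query_segment, memory_segment, neg_inf=-1e9):
  # query_segment: [batch, query_length]; memory_segment: [batch, memory_length].
  # Positions may only attend within their own packed example.
  different = query_segment[:, :, None] != memory_segment[:, None, :]
  return different.astype(np.float32)[:, None, :, :] * neg_inf

seg = np.array([[1, 1, 2, 2]])  # two examples packed into a single row
print(same_segment_bias_sketch(seg, seg)[0, 0])
# -> 0 inside each example, -1e9 across the example boundary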
Example #7
def transformer_prepare_decoder(targets, hparams):
    """Prepare one shard of the model for the decoder.

  Args:
    targets: a Tensor.
    hparams: run hyperparameters

  Returns:
    decoder_input: a Tensor, bottom of decoder stack
    decoder_self_attention_bias: a bias tensor for use in decoder self-attention
  """
    decoder_self_attention_bias = (
        common_attention.attention_bias_lower_triangle(tf.shape(targets)[1]))
    if hparams.proximity_bias:
        decoder_self_attention_bias += common_attention.attention_bias_proximal(
            tf.shape(targets)[1])
    decoder_input = common_layers.shift_right_3d(targets)
    return (decoder_input, decoder_self_attention_bias)
Example #8
def transformer_prepare_encoder(inputs, target_space, hparams):
    """Prepare one shard of the model for the encoder.
  
    Args:
      inputs: Tensor with shape [batch, memory_length, depth]
      target_space: a Tensor.
      hparams: run hyperparameters
  
    Returns:
      encoder_input: a Tensor, bottom of encoder stack
      encoder_self_attention_bias: a bias tensor for use in encoder self-attention
      encoder_decoder_attention_bias: a bias tensor for use in encoder-decoder
        attention
    """

    ignore_padding = get_ignore_padding(inputs)
    encoder_self_attention_bias = ignore_padding

    # Bias for self-attention to encourage attention to close positions.
    if hparams.proximity_bias:
        encoder_self_attention_bias += comm_attn.attention_bias_proximal(
            length=tf.shape(inputs)[1])

    # Append target_space_id embedding to inputs.
    emb_target_space = common_layers.embedding(
        x=target_space,
        vocab_size=32,
        dense_size=inputs.shape.as_list()[-1],
        name='target_space_embedding')
    emb_target_space = tf.reshape(emb_target_space, [1, 1, -1])

    # Broadcast the [1, 1, depth] target-space embedding over batch and time.
    encoder_input = inputs + emb_target_space
    if hparams.pos == 'timing':
        encoder_input = comm_attn.add_timing_signal_1d(encoder_input)
    # Putting this here since always called immediately after...
    encoder_input = with_dropout(encoder_input, hparams)

    return EncoderState(input=encoder_input,
                        self_attn_bias=encoder_self_attention_bias,
                        decoder_attn_bias=ignore_padding,
                        output=None)
Example #9
def transformer_prepare_encoder(inputs, target_space, hparams):
    """Copied from tensor2tensor.models.transformer."""
    ishape_static = inputs.shape.as_list()
    encoder_input = inputs
    encoder_padding = common_attention.embedding_to_padding(encoder_input)
    ignore_padding = common_attention.attention_bias_ignore_padding(
        encoder_padding)
    encoder_self_attention_bias = ignore_padding
    encoder_decoder_attention_bias = ignore_padding
    if hparams.proximity_bias:
        encoder_self_attention_bias += common_attention.attention_bias_proximal(
            tf.shape(inputs)[1])
    # Append target_space_id embedding to inputs.
    emb_target_space = common_layers.embedding(target_space,
                                               32,
                                               ishape_static[-1],
                                               name="target_space_embedding")
    emb_target_space = tf.reshape(emb_target_space, [1, 1, -1])
    encoder_input += emb_target_space
    if hparams.pos == "timing":
        encoder_input = common_attention.add_timing_signal_1d(encoder_input)
    return (encoder_input, encoder_self_attention_bias,
            encoder_decoder_attention_bias)
Example #10
def transformer_prepare_encoder(inputs, target_space, hparams):
    """Prepare one shard of the model for the encoder.

  Args:
    inputs: a Tensor.
    target_space: a Tensor.
    hparams: run hyperparameters

  Returns:
    encoder_input: a Tensor, bottom of encoder stack
    encoder_self_attention_bias: a bias tensor for use in encoder self-attention
    encoder_decoder_attention_bias: a bias tensor for use in encoder-decoder
      attention
  """
    ishape_static = inputs.shape.as_list()
    encoder_input = inputs
    encoder_padding = common_attention.embedding_to_padding(encoder_input)
    ignore_padding = common_attention.attention_bias_ignore_padding(
        encoder_padding)
    encoder_self_attention_bias = ignore_padding
    encoder_decoder_attention_bias = ignore_padding
    if hparams.proximity_bias:
        encoder_self_attention_bias += common_attention.attention_bias_proximal(
            common_layers.shape_list(inputs)[1])
    # Append target_space_id embedding to inputs.
    emb_target_space = common_layers.embedding(
        target_space,
        32,
        ishape_static[-1],
        name="target_space_embedding",
        use_eager_mode=hparams.use_eager_mode)
    emb_target_space = tf.reshape(emb_target_space, [1, 1, -1])
    encoder_input += emb_target_space
    if hparams.pos == "timing":
        encoder_input = common_attention.add_timing_signal_1d(encoder_input)
    return (encoder_input, encoder_self_attention_bias,
            encoder_decoder_attention_bias)
Example #11
def transformer_prepare_encoder(inputs, target_space, hparams):
  """Prepare one shard of the model for the encoder.

  Args:
    inputs: a Tensor.
    target_space: a Tensor.
    hparams: run hyperparameters

  Returns:
    encoder_input: a Tensor, bottom of encoder stack
    encoder_self_attention_bias: a bias tensor for use in encoder self-attention
    encoder_decoder_attention_bias: a bias tensor for use in encoder-decoder
      attention
  """
  ishape_static = inputs.shape.as_list()
  encoder_input = inputs
  encoder_padding = common_attention.embedding_to_padding(encoder_input)
  ignore_padding = common_attention.attention_bias_ignore_padding(
      encoder_padding)
  encoder_self_attention_bias = ignore_padding
  encoder_decoder_attention_bias = ignore_padding
  if hparams.proximity_bias:
    encoder_self_attention_bias += common_attention.attention_bias_proximal(
        tf.shape(inputs)[1])
  # Append target_space_id embedding to inputs.
  emb_target_space = common_layers.embedding(
      target_space, 32, ishape_static[-1], name="target_space_embedding")
  emb_target_space = tf.reshape(emb_target_space, [1, 1, -1])
  encoder_input += emb_target_space

  # random_uniform_mask = tf.expand_dims(tf.to_float(tf.to_int32(tf.random_uniform([tf.shape(encoder_input)[0], tf.shape(encoder_input)[1]]) < hparams.mask_noise_prob)), axis=2)
  # encoder_input = encoder_input * (1 - random_uniform_mask)

  if hparams.pos == "timing":
    encoder_input = common_attention.add_timing_signal_1d(encoder_input)
  return (encoder_input, encoder_self_attention_bias,
          encoder_decoder_attention_bias)
Example #12
def transformer_prepare_decoder(targets, hparams):
    """Prepare one shard of the model for the decoder.
  
    Args:
      targets: a Tensor.
      hparams: run hyperparameters
  
    Returns:
      decoder_input: a Tensor, bottom of decoder stack
      decoder_self_attention_bias: a bias tensor for use in decoder self-attention
    """
    decoder_self_attention_bias = (comm_attn.attention_bias_lower_triangle(
        tf.shape(targets)[1]))
    if hparams.proximity_bias:
        decoder_self_attention_bias += comm_attn.attention_bias_proximal(
            tf.shape(targets)[1])
    decoder_input = common_layers.shift_right_3d(targets)
    if hparams.pos == 'timing':
        decoder_input = comm_attn.add_timing_signal_1d(decoder_input)
    # Putting this here since always called immediately after...
    decoder_input = with_dropout(decoder_input, hparams)

    return DecoderState(input=decoder_input,
                        self_attn_bias=decoder_self_attention_bias)
Example #13
def transformer_prepare_decoder(targets, hparams):
  """Prepare one shard of the model for the decoder.

  Args:
    targets: a Tensor.
    hparams: run hyperparameters

  Returns:
    decoder_input: a Tensor, bottom of decoder stack
    decoder_self_attention_bias: a bias tensor for use in decoder self-attention
  """
  decoder_self_attention_bias = (
      common_attention.attention_bias_lower_triangle(tf.shape(targets)[1]))
  if hparams.proximity_bias:
    decoder_self_attention_bias += common_attention.attention_bias_proximal(
        tf.shape(targets)[1])
  decoder_input = common_layers.shift_right_3d(targets)
  if hparams.pos == "timing":
    decoder_input = common_attention.add_timing_signal_1d(decoder_input)
  # decoder_input = tf.Print(decoder_input, [tf.shape(decoder_input)], 
  #     summarize=1000, message="decoder_input")
  # decoder_self_attention_bias = tf.Print(decoder_self_attention_bias, [tf.shape(decoder_self_attention_bias)], 
  #     summarize=1000, message="decoder_self_attention_bias")
  return (decoder_input, decoder_self_attention_bias)
Example #14
def transformer_prepare_encoder(inputs, target_space, hparams, features=None):
  """Prepare one shard of the model for the encoder.

  Args:
    inputs: a Tensor.
    target_space: a Tensor.
    hparams: run hyperparameters
    features: optionally pass the entire features dictionary as well.
      This is needed now for "packed" datasets.

  Returns:
    encoder_input: a Tensor, bottom of encoder stack
    encoder_self_attention_bias: a bias tensor for use in encoder self-attention
    encoder_decoder_attention_bias: a bias tensor for use in encoder-decoder
      attention
  """
  ishape_static = inputs.shape.as_list()
  encoder_input = inputs
  if features and "inputs_segmentation" in features:
    # Packed dataset.  Keep the examples from seeing each other.
    inputs_segmentation = features["inputs_segmentation"]
    inputs_position = features["inputs_position"]
    targets_segmentation = features["targets_segmentation"]
    if (hasattr(hparams, "unidirectional_encoder") and
        hparams.unidirectional_encoder):
      tf.logging.info("Using unidirectional encoder")
      encoder_self_attention_bias = (
          common_attention.attention_bias_lower_triangle(
              common_layers.shape_list(inputs)[1]))
    else:
      encoder_self_attention_bias = (
          common_attention.attention_bias_same_segment(
              inputs_segmentation, inputs_segmentation))
    encoder_decoder_attention_bias = (
        common_attention.attention_bias_same_segment(targets_segmentation,
                                                     inputs_segmentation))
  else:
    encoder_padding = common_attention.embedding_to_padding(encoder_input)
    ignore_padding = common_attention.attention_bias_ignore_padding(
        encoder_padding)
    if (hasattr(hparams, "unidirectional_encoder") and
        hparams.unidirectional_encoder):
      tf.logging.info("Using unidirectional encoder")
      encoder_self_attention_bias = (
          common_attention.attention_bias_lower_triangle(
              common_layers.shape_list(inputs)[1]))
    else:
      # Usual case - not a packed dataset.
      encoder_self_attention_bias = ignore_padding
    encoder_decoder_attention_bias = ignore_padding
    inputs_position = None
  if hparams.proximity_bias:
    encoder_self_attention_bias += common_attention.attention_bias_proximal(
        common_layers.shape_list(inputs)[1])
  if hparams.get("use_target_space_embedding", True):
    # Append target_space_id embedding to inputs.
    emb_target_space = common_layers.embedding(
        target_space,
        32,
        ishape_static[-1],
        name="target_space_embedding",
        dtype=tf.bfloat16
        if hparams.activation_dtype == "bfloat16" else tf.float32)
    emb_target_space = tf.reshape(emb_target_space, [1, 1, -1])
    encoder_input += emb_target_space
  if hparams.pos == "timing":
    if inputs_position is not None:
      encoder_input = common_attention.add_timing_signal_1d_given_position(
          encoder_input, inputs_position)
    else:
      encoder_input = common_attention.add_timing_signal_1d(encoder_input)
  elif hparams.pos == "emb":
    encoder_input = common_attention.add_positional_embedding(
        encoder_input, hparams.max_length, "inputs_positional_embedding",
        inputs_position)
  if hparams.activation_dtype == "bfloat16":
    encoder_self_attention_bias = tf.cast(encoder_self_attention_bias,
                                          tf.bfloat16)
    encoder_decoder_attention_bias = tf.cast(encoder_decoder_attention_bias,
                                             tf.bfloat16)
  return (encoder_input, encoder_self_attention_bias,
          encoder_decoder_attention_bias)
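For reference, the "timing" position option used throughout these examples adds the sinusoidal signal from the Transformer paper. A compact NumPy sketch of the idea, using the usual min_timescale=1 and max_timescale=1e4 defaults (illustrative rather than the library implementation, and assuming an even channel count):
import numpy as np

def timing_signal_sketch(length, channels, min_timescale=1.0,
                         max_timescale=1.0e4):
  # Sinusoids at geometrically spaced timescales, concatenated over channels.
  # The result is broadcast-added to a [batch, length, channels] input.
  position = np.arange(length, dtype=np.float32)
  num_timescales = channels // 2
  log_inc = np.log(max_timescale / min_timescale) / max(num_timescales - 1, 1)
  inv_timescales = min_timescale * np.exp(
      -log_inc * np.arange(num_timescales, dtype=np.float32))
  scaled_time = position[:, None] * inv_timescales[None, :]
  signal = np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1)
  return signal[None, :, :]  # [1, length, channels]

x = np.zeros((2, 10, 64), dtype=np.float32)
x = x + timing_signal_sketch(10, 64)  # what add_timing_signal_1d amounts to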
Example #15
    def _fast_decode(self,
                     features,
                     decode_length,
                     beam_size=1,
                     top_beams=1,
                     alpha=1.0,
                     sentence_cache=None):
        """Fast decoding.

    Implements both greedy and beam search decoding, uses beam search iff
    beam_size > 1, otherwise beam search related arguments are ignored.

    Args:
      features: a map of string to model  features.
      decode_length: an integer.  How many additional timesteps to decode.
      beam_size: number of beams.
      top_beams: an integer. How many of the beams to return.
      alpha: Float that controls the length penalty. larger the alpha, stronger
        the preference for longer translations.

    Returns:
      A dict of decoding results {
          "outputs": integer `Tensor` of decoded ids of shape
              [batch_size, <= decode_length] if beam_size == 1 or
              [batch_size, top_beams, <= decode_length]
          "scores": decoding log probs from the beam search,
              None if using greedy decoding (beam_size=1)
      }

    Raises:
      NotImplementedError: If there are multiple data shards.
    """
        if self._num_datashards != 1:
            raise NotImplementedError(
                "Fast decoding only supports a single shard.")
        dp = self._data_parallelism
        hparams = self._hparams
        target_modality = self._problem_hparams.target_modality

        if self.has_input:
            inputs = features["inputs"]
            if target_modality.is_class_modality:
                decode_length = 1
            else:
                decode_length = common_layers.shape_list(
                    inputs)[1] + decode_length

            # TODO(llion): Clean up this reshaping logic.
            inputs = tf.expand_dims(inputs, axis=1)
            if len(inputs.shape) < 5:
                inputs = tf.expand_dims(inputs, axis=4)
            s = common_layers.shape_list(inputs)
            batch_size = s[0]
            inputs = tf.reshape(inputs, [s[0] * s[1], s[2], s[3], s[4]])
            # _shard_features called to ensure that the variable names match
            inputs = self._shard_features({"inputs": inputs})["inputs"]
            input_modality = self._problem_hparams.input_modality["inputs"]
            with tf.variable_scope(input_modality.name):
                inputs = input_modality.bottom_sharded(inputs, dp)
            with tf.variable_scope("body"):
                encoder_output, encoder_decoder_attention_bias = dp(
                    self.encode,
                    inputs,
                    features["target_space_id"],
                    hparams,
                    features=features)
            encoder_output = encoder_output[0]
            encoder_decoder_attention_bias = encoder_decoder_attention_bias[0]
            partial_targets = None
        else:
            # The problem has no inputs.
            # In this case, features["inputs"] contains partial targets.
            # We force the outputs to begin with these sequences.
            encoder_output = None
            encoder_decoder_attention_bias = None
            partial_targets = tf.squeeze(tf.to_int64(features["inputs"]),
                                         [2, 3])
            partial_targets_length = common_layers.shape_list(
                partial_targets)[1]
            decode_length += partial_targets_length
            batch_size = tf.shape(partial_targets)[0]

        if hparams.pos == "timing":
            timing_signal = common_attention.get_timing_signal_1d(
                decode_length + 1, hparams.hidden_size)

        def preprocess_targets(targets, i):
            """Performs preprocessing steps on the targets to prepare for the decoder.

      This includes:
        - Embedding the ids.
        - Flattening to 3D tensor.
        - Optionally adding timing signals.

      Args:
        targets: inputs ids to the decoder. [batch_size, 1]
        i: scalar, Step number of the decoding loop.

      Returns:
        Processed targets [batch_size, 1, hidden_dim]
      """
            # _shard_features called to ensure that the variable names match
            targets = self._shard_features({"targets": targets})["targets"]
            with tf.variable_scope(target_modality.name):
                targets = target_modality.targets_bottom_sharded(targets,
                                                                 dp)[0]
            targets = common_layers.flatten4d3d(targets)

            # TODO(llion): Explain! Is this even needed?
            targets = tf.cond(tf.equal(i, 0), lambda: tf.zeros_like(targets),
                              lambda: targets)

            if hparams.pos == "timing":
                targets += timing_signal[:, i:i + 1]
            return targets

        decoder_self_attention_bias = (
            common_attention.attention_bias_lower_triangle(decode_length))
        if hparams.proximity_bias:
            decoder_self_attention_bias += common_attention.attention_bias_proximal(
                decode_length)

        def symbols_to_logits_fn(ids, i, cache):
            """Go from ids to logits for next symbol."""
            ids = ids[:, -1:]
            targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
            targets = preprocess_targets(targets, i)

            bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]

            with tf.variable_scope("body"):
                body_outputs = dp(self.decode,
                                  targets,
                                  cache.get("encoder_output"),
                                  cache.get("encoder_decoder_attention_bias"),
                                  bias,
                                  hparams,
                                  cache,
                                  nonpadding=features_to_nonpadding(
                                      features, "targets"))

            with tf.variable_scope(target_modality.name):
                logits = target_modality.top_sharded(body_outputs, None, dp)[0]

            ret = tf.squeeze(logits, axis=[1, 2, 3])
            if partial_targets is not None:
                # If the position is within the given partial targets, we alter the
                # logits to always return those values.
                # A faster approach would be to process the partial targets in one
                # iteration in order to fill the corresponding parts of the cache.
                # This would require broader changes, though.
                vocab_size = tf.shape(ret)[1]

                def forced_logits():
                    return tf.one_hot(
                        tf.tile(partial_targets[:, i], [beam_size]),
                        vocab_size, 0.0, -1e9)

                ret = tf.cond(tf.less(i, partial_targets_length),
                              forced_logits, lambda: ret)
            return ret, cache, body_outputs

        ret = fast_decode(
            encoder_output=encoder_output,
            encoder_decoder_attention_bias=encoder_decoder_attention_bias,
            symbols_to_logits_fn=symbols_to_logits_fn,
            hparams=hparams,
            decode_length=decode_length,
            vocab_size=target_modality.top_dimensionality,
            beam_size=beam_size,
            top_beams=top_beams,
            alpha=alpha,
            batch_size=batch_size,
            sentence_cache=self.sentence_cache,
            cache_flag=self.cache_flag)
        if partial_targets is not None:
            ret["outputs"] = ret["outputs"][:, partial_targets_length:]
        return ret
  def _fast_decode(self,
                   features,
                   decode_length,
                   beam_size=1,
                   top_beams=1,
                   alpha=1.0):
    """ Fast decoding
    Implements both greedy and beam search decoding, uses beam search iff
    beam_size > 1, otherwise beam search related arguments are ignored.
    :param features: a map of string to model  features.
    :param decode_length:
    :param beam_size: beam search size
    :param top_beams: an integer, how many of the beams to return
    :param alpha:
    :return:
    """
    if self._num_datashards != 1:
      raise NotImplementedError("Fast decoding only supports a single shard.")
    dp = self._data_parallelism
    hparams = self._hparams
    target_modality = self._problem_hparams.target_modality

    assert self.has_input, "problems for dual-transformer must have inputs"
    # decode with an input source(needs encoder outputs)
    wav_inputs = features["wav_inputs"]
    txt_inputs = features["txt_inputs"]
    if target_modality.is_class_modality:
      decode_length = 1
    else:
      decode_length = (common_layers.shape_list(wav_inputs)[1] +
                       features.get("decode_length", decode_length))

    wav_inputs = tf.expand_dims(wav_inputs, axis=1)
    txt_inputs = tf.expand_dims(txt_inputs, axis=1)
    if len(wav_inputs.shape) < 5:
      wav_inputs = tf.expand_dims(wav_inputs, axis=4)
    if len(txt_inputs.shape) < 5:
      txt_inputs = tf.expand_dims(txt_inputs, axis=4)

    s = common_layers.shape_list(wav_inputs)
    batch_size = s[0]
    wav_inputs = tf.reshape(wav_inputs, [s[0] * s[1], s[2], s[3], s[4]])
    txt_inputs = tf.reshape(txt_inputs, [s[0] * s[1], s[2], s[3], s[4]])
    # _shard_features called to ensure that the variable names match
    wav_inputs = self._shard_features({"wav_inputs": wav_inputs})["wav_inputs"]
    wav_input_modality = self._problem_hparams.input_modality["wav_inputs"]
    txt_inputs = self._shard_features({"txt_inputs": txt_inputs})["txt_inputs"]
    txt_input_modality = self._problem_hparams.input_modality["txt_inputs"]

    with tf.variable_scope(wav_input_modality.name):
      wav_inputs = wav_input_modality.bottom_sharded(wav_inputs, dp)
    with tf.variable_scope(txt_input_modality.name):
      txt_inputs = txt_input_modality.bottom_sharded(txt_inputs, dp)

    with tf.variable_scope("body"):
      wav_enc_output, wav_enc_dec_attention_bias, \
      txt_enc_output,txt_enc_dec_attention_bias = dp(
        self.dual_encode, wav_inputs, txt_inputs, features["target_space_id"], hparams,
        features=features)
    wav_enc_output = wav_enc_output[0]
    txt_enc_output = txt_enc_output[0]
    wav_enc_dec_attention_bias = wav_enc_dec_attention_bias[0]
    txt_enc_dec_attention_bias = txt_enc_dec_attention_bias[0]

    if hparams.pos == "timing":
      timing_signal = common_attention.get_timing_signal_1d(
        decode_length + 1, hparams.hidden_size)

    def preprocess_targets(targets, i):
      """Performs preprocessing steps on the targets to prepare for the decoder.

      This includes:
        - Embedding the ids.
        - Flattening to 3D tensor.
        - Optionally adding timing signals.

      Args:
        targets: inputs ids to the decoder. [batch_size, 1]
        i: scalar, Step number of the decoding loop.

      Returns:
        Processed targets [batch_size, 1, hidden_dim]
      """
      # _shard_features called to ensure that the variable names match
      targets = self._shard_features({"targets": targets})["targets"]
      with tf.variable_scope(target_modality.name):
        targets = target_modality.targets_bottom_sharded(targets, dp)[0]
      targets = common_layers.flatten4d3d(targets)

      # TODO(llion): Explain! Is this even needed?
      targets = tf.cond(
        tf.equal(i, 0), lambda: tf.zeros_like(targets), lambda: targets)

      if hparams.pos == "timing":
        targets += timing_signal[:, i:i + 1]
      return targets

    decoder_self_attention_bias = (
      common_attention.attention_bias_lower_triangle(decode_length))
    if hparams.proximity_bias:
      decoder_self_attention_bias += common_attention.attention_bias_proximal(
        decode_length)

    def symbols_to_logits_fn(ids, i, cache):
      """Go from ids to logits for next symbol."""
      ids = ids[:, -1:]
      targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
      targets = preprocess_targets(targets, i)

      bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]

      with tf.variable_scope("body"):
        body_outputs = dp(
          self.dual_decode, targets,
          cache.get("wav_enc_outputs"),
          cache.get("txt_enc_outputs"),
          cache.get("wav_enc_dec_attention_bias"),
          cache.get("txt_enc_dec_attention_bias"),
          bias, hparams, cache,
          nonpadding=features_to_nonpadding(features, "targets"))

      with tf.variable_scope(target_modality.name):
        logits = target_modality.top_sharded(body_outputs, None, dp)[0]

      ret = tf.squeeze(logits, axis=[1, 2, 3])

      return ret, cache

    ret = fast_decode(
      wav_encoder_output=wav_enc_output,
      txt_encoder_output=txt_enc_output,
      wav_enc_dec_attention_bias=wav_enc_dec_attention_bias,
      txt_enc_dec_attention_bias=txt_enc_dec_attention_bias,
      symbols_to_logits_fn=symbols_to_logits_fn,
      hparams=hparams,
      decode_length=decode_length,
      vocab_size=target_modality.top_dimensionality,
      beam_size=beam_size,
      top_beams=top_beams,
      alpha=alpha,
      batch_size=batch_size,
      force_decode_length=self._decode_hparams.force_decode_length)

    return ret
    def _fast_decode(self,
                     features,
                     decode_length,
                     beam_size=1,
                     top_beams=1,
                     alpha=1.0):
        """Fast decoding.

    Implements both greedy and beam search decoding, uses beam search iff
    beam_size > 1, otherwise beam search related arguments are ignored.

    Args:
      features: a map of string to model  features.
      decode_length: an integer.  How many additional timesteps to decode.
      beam_size: number of beams.
      top_beams: an integer. How many of the beams to return.
      alpha: Float that controls the length penalty. larger the alpha, stronger
        the preference for longer translations.

    Returns:
       samples: an integer `Tensor`. Top samples from the beam search

    Raises:
      NotImplementedError: If there are multiple data shards.
    """
        if self._num_datashards != 1:
            raise NotImplementedError(
                "Fast decoding only supports a single shard.")
        dp = self._data_parallelism
        hparams = self._hparams

        inputs = features["inputs"]
        target_modality = self._problem_hparams.target_modality
        if target_modality.is_class_modality:
            decode_length = 1
        else:
            decode_length = common_layers.shape_list(inputs)[1] + decode_length

        # TODO(llion): Clean up this reshaping logic.
        inputs = tf.expand_dims(inputs, axis=1)
        if len(inputs.shape) < 5:
            inputs = tf.expand_dims(inputs, axis=4)
        s = common_layers.shape_list(inputs)
        inputs = tf.reshape(inputs, [s[0] * s[1], s[2], s[3], s[4]])
        # _shard_features called to ensure that the variable names match
        inputs = self._shard_features({"inputs": inputs})["inputs"]
        input_modality = self._problem_hparams.input_modality["inputs"]
        with tf.variable_scope(input_modality.name):
            inputs = input_modality.bottom_sharded(inputs, dp)
        with tf.variable_scope("body"):
            encoder_output, encoder_decoder_attention_bias = dp(
                self.encode,
                inputs,
                features["target_space_id"],
                hparams,
                features=features)
        encoder_output = encoder_output[0]
        encoder_decoder_attention_bias = encoder_decoder_attention_bias[0]

        if hparams.pos == "timing":
            timing_signal = common_attention.get_timing_signal_1d(
                decode_length + 1, hparams.hidden_size)

        def preprocess_targets(targets, i):
            """Performs preprocessing steps on the targets to prepare for the decoder.

      This includes:
        - Embedding the ids.
        - Flattening to 3D tensor.
        - Optionally adding timing signals.

      Args:
        targets: inputs ids to the decoder. [batch_size, 1]
        i: scalar, Step number of the decoding loop.

      Returns:
        Processed targets [batch_size, 1, hidden_dim]
      """
            # _shard_features called to ensure that the variable names match
            targets = self._shard_features({"targets": targets})["targets"]
            with tf.variable_scope(target_modality.name):
                targets = target_modality.targets_bottom_sharded(targets,
                                                                 dp)[0]
            targets = common_layers.flatten4d3d(targets)

            # TODO(llion): Explain! Is this even needed?
            targets = tf.cond(tf.equal(i, 0), lambda: tf.zeros_like(targets),
                              lambda: targets)

            if hparams.pos == "timing":
                targets += timing_signal[:, i:i + 1]
            return targets

        decoder_self_attention_bias = (
            common_attention.attention_bias_lower_triangle(decode_length))
        if hparams.proximity_bias:
            decoder_self_attention_bias += common_attention.attention_bias_proximal(
                decode_length)

        def symbols_to_logits_fn(ids, i, cache):
            """Go from ids to logits for next symbol."""
            ids = ids[:, -1:]
            targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
            targets = preprocess_targets(targets, i)

            bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]

            with tf.variable_scope("body"):
                body_outputs = dp(self.decode,
                                  targets,
                                  cache["encoder_output"],
                                  cache["encoder_decoder_attention_bias"],
                                  bias,
                                  hparams,
                                  cache,
                                  nonpadding=features_to_nonpadding(
                                      features, "targets"))

            with tf.variable_scope(target_modality.name):
                logits = target_modality.top_sharded(body_outputs, None, dp)[0]

            return tf.squeeze(logits, axis=[1, 2, 3]), cache

        return fast_decode(
            encoder_output=encoder_output,
            encoder_decoder_attention_bias=encoder_decoder_attention_bias,
            symbols_to_logits_fn=symbols_to_logits_fn,
            hparams=hparams,
            decode_length=decode_length,
            vocab_size=target_modality.top_dimensionality,
            beam_size=beam_size,
            top_beams=top_beams,
            alpha=alpha)
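The symbols_to_logits_fn closures above are driven by fast_decode. Stripped of beam search, caching details, and TensorFlow plumbing, the greedy path boils down to a loop like the following framework-free sketch (names and the argmax loop are illustrative, not the actual fast_decode implementation; it assumes the common two-value (logits, cache) return form):
import numpy as np

def greedy_fast_decode_sketch(symbols_to_logits_fn, batch_size, decode_length,
                              cache):
  # Start from the all-zeros GO symbol and extend the sequence one id per step.
  ids = np.zeros((batch_size, 1), dtype=np.int64)
  for i in range(decode_length):
    logits, cache = symbols_to_logits_fn(ids, i, cache)  # [batch, vocab]
    next_ids = np.argmax(logits, axis=-1)[:, None]
    ids = np.concatenate([ids, next_ids], axis=1)
  return ids[:, 1:]  # drop the GO symbol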
Example #18
  def _fast_decode(self,
                   features,
                   decode_length,
                   beam_size=1,
                   top_beams=1,
                   alpha=1.0):
    """Fast decoding.

    Implements both greedy and beam search decoding, uses beam search iff
    beam_size > 1, otherwise beam search related arguments are ignored.

    Args:
      features: a map of string to model  features.
      decode_length: an integer.  How many additional timesteps to decode.
      beam_size: number of beams.
      top_beams: an integer. How many of the beams to return.
      alpha: Float that controls the length penalty. larger the alpha, stronger
        the preference for longer translations.

    Returns:
      A dict of decoding results {
          "outputs": integer `Tensor` of decoded ids of shape
              [batch_size, <= decode_length] if beam_size == 1 or
              [batch_size, top_beams, <= decode_length]
          "scores": decoding log probs from the beam search,
              None if using greedy decoding (beam_size=1)
      }

    Raises:
      NotImplementedError: If there are multiple data shards.
    """
    if self._num_datashards != 1:
      raise NotImplementedError("Fast decoding only supports a single shard.")
    dp = self._data_parallelism
    hparams = self._hparams
    target_modality = self._problem_hparams.modality["targets"]
    target_vocab_size = self._problem_hparams.vocab_size["targets"]
    if target_vocab_size is not None and hasattr(hparams, "vocab_divisor"):
      target_vocab_size += (-target_vocab_size) % hparams.vocab_divisor
    if "targets_segmentation" in features:
      raise NotImplementedError(
          "Decoding not supported on packed datasets "
          " If you want to decode from a dataset, use the non-packed version"
          " of the dataset when decoding.")
    if self.has_input:
      inputs = features["inputs"]
      if target_modality == modalities.ModalityType.CLASS_LABEL:
        decode_length = 1
      else:
        decode_length = (
            common_layers.shape_list(inputs)[1] + features.get(
                "decode_length", decode_length))

      # TODO(llion): Clean up this reshaping logic.
      inputs = tf.expand_dims(inputs, axis=1)
      if len(inputs.shape) < 5:
        inputs = tf.expand_dims(inputs, axis=4)
      s = common_layers.shape_list(inputs)
      batch_size = s[0]
      inputs = tf.reshape(inputs, [s[0] * s[1], s[2], s[3], s[4]])
      # _shard_features called to ensure that the variable names match
      inputs = self._shard_features({"inputs": inputs})["inputs"]
      input_modality = self._problem_hparams.modality["inputs"]
      input_vocab_size = self._problem_hparams.vocab_size["inputs"]
      if input_vocab_size is not None and hasattr(hparams, "vocab_divisor"):
        input_vocab_size += (-input_vocab_size) % hparams.vocab_divisor
      modality_name = hparams.name.get(
          "inputs",
          modalities.get_name(input_modality))(hparams, input_vocab_size)
      with tf.variable_scope(modality_name):
        bottom = hparams.bottom.get("inputs",
                                    modalities.get_bottom(input_modality))
        inputs = dp(bottom, inputs, hparams, input_vocab_size)
      with tf.variable_scope("body"):
        encoder_output, encoder_decoder_attention_bias = dp(
            self.encode,
            inputs,
            features["target_space_id"],
            hparams,
            features=features)
      encoder_output = encoder_output[0]
      encoder_decoder_attention_bias = encoder_decoder_attention_bias[0]
      if 'partial_targets' in features:
        partial_targets = features['partial_targets']
      else:
        partial_targets = None
    else:
      # The problem has no inputs.
      encoder_output = None
      encoder_decoder_attention_bias = None

      # Prepare partial targets.
      # In either features["inputs"] or features["targets"].
      # We force the outputs to begin with these sequences.
      partial_targets = features.get("inputs")
      if partial_targets is None:
        partial_targets = features["targets"]
      assert partial_targets is not None

    if partial_targets is not None:
      partial_targets = common_layers.expand_squeeze_to_nd(partial_targets, 2)
      partial_targets = tf.to_int64(partial_targets)
      partial_targets_shape = common_layers.shape_list(partial_targets)
      partial_targets_length = partial_targets_shape[1]
      decode_length = (
          partial_targets_length + features.get("decode_length", decode_length))
      batch_size = partial_targets_shape[0]

    if hparams.pos == "timing":
      positional_encoding = common_attention.get_timing_signal_1d(
          decode_length + 1, hparams.hidden_size)
    elif hparams.pos == "emb":
      positional_encoding = common_attention.add_positional_embedding(
          tf.zeros([1, decode_length, hparams.hidden_size]), hparams.max_length,
          "body/targets_positional_embedding", None)
    else:
      positional_encoding = None

    def preprocess_targets(targets, i):
      """Performs preprocessing steps on the targets to prepare for the decoder.

      This includes:
        - Embedding the ids.
        - Flattening to 3D tensor.
        - Optionally adding timing signals.

      Args:
        targets: inputs ids to the decoder. [batch_size, 1]
        i: scalar, Step number of the decoding loop.

      Returns:
        Processed targets [batch_size, 1, hidden_dim]
      """
      # _shard_features called to ensure that the variable names match
      targets = self._shard_features({"targets": targets})["targets"]
      modality_name = hparams.name.get(
          "targets",
          modalities.get_name(target_modality))(hparams, target_vocab_size)
      with tf.variable_scope(modality_name):
        bottom = hparams.bottom.get(
            "targets", modalities.get_targets_bottom(target_modality))
        targets = dp(bottom, targets, hparams, target_vocab_size)[0]
      targets = common_layers.flatten4d3d(targets)

      # GO embeddings are all zero, this is because transformer_prepare_decoder
      # Shifts the targets along by one for the input which pads with zeros.
      # If the modality already maps GO to the zero embeddings this is not
      # needed.
      targets = tf.cond(
          tf.equal(i, 0), lambda: tf.zeros_like(targets), lambda: targets)

      if positional_encoding is not None:
        targets += positional_encoding[:, i:i + 1]
      return targets

    decoder_self_attention_bias = (
        common_attention.attention_bias_lower_triangle(decode_length))
    if hparams.proximity_bias:
      decoder_self_attention_bias += common_attention.attention_bias_proximal(
          decode_length)

    # Create tensors for encoder-decoder attention history
    att_cache = {"attention_history": {}}
    num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers
    att_batch_size, enc_seq_length = common_layers.shape_list(encoder_output)[0:2]
    for layer in range(num_layers):
      att_cache["attention_history"]["layer_%d" % layer] = tf.zeros(
        [att_batch_size, hparams.num_heads, 0, enc_seq_length])
    att_cache["body_outputs"] = tf.zeros([att_batch_size, 1, 0, hparams.hidden_size])

    def update_decoder_attention_history(cache):
      for k in filter(lambda x: "decoder" in x and not "self" in x and not "logits" in x,
        self.attention_weights.keys()):
        m = re.search(r"(layer_\d+)", k)
        if m is None:
          continue
        cache["attention_history"][m[0]] = tf.concat(
            [cache["attention_history"][m[0]], self.attention_weights[k]], axis=2)

    def symbols_to_logits_fn(ids, i, cache):
      """Go from ids to logits for next symbol."""
      ids = ids[:, -1:]
      targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
      targets = preprocess_targets(targets, i)

      bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]

      with tf.variable_scope("body"):
        body_outputs = dp(
            self.decode,
            targets,
            cache.get("encoder_output"),
            cache.get("encoder_decoder_attention_bias"),
            bias,
            hparams,
            cache,
            nonpadding=features_to_nonpadding(features, "targets"))

      update_decoder_attention_history(cache)
      cache["body_outputs"] = tf.concat([cache["body_outputs"], body_outputs[0]], axis=2)

      modality_name = hparams.name.get(
          "targets",
          modalities.get_name(target_modality))(hparams, target_vocab_size)
      with tf.variable_scope(modality_name):
        top = hparams.top.get("targets", modalities.get_top(target_modality))
        logits = dp(top, body_outputs, None, hparams, target_vocab_size)[0]

      ret = tf.squeeze(logits, axis=[1, 2, 3])
      if partial_targets is not None:
        # If the position is within the given partial targets, we alter the
        # logits to always return those values.
        # A faster approach would be to process the partial targets in one
        # iteration in order to fill the corresponding parts of the cache.
        # This would require broader changes, though.
        vocab_size = tf.shape(ret)[1]

        def forced_logits():
          return tf.one_hot(
              tf.tile(partial_targets[:, i], [beam_size]), vocab_size, 0.0,
              -1e9)

        ret = tf.cond(
            tf.less(i, partial_targets_length), forced_logits, lambda: ret)
      return ret, cache

    ret = fast_decode(
        encoder_output=encoder_output,
        encoder_decoder_attention_bias=encoder_decoder_attention_bias,
        symbols_to_logits_fn=symbols_to_logits_fn,
        hparams=hparams,
        decode_length=decode_length,
        vocab_size=target_vocab_size,
        beam_size=beam_size,
        top_beams=top_beams,
        alpha=alpha,
        batch_size=batch_size,
        force_decode_length=self._decode_hparams.force_decode_length,
        cache=att_cache)

    if partial_targets is not None:
      if beam_size <= 1 or top_beams <= 1:
        ret["outputs"] = ret["outputs"][:, partial_targets_length:]
      else:
        ret["outputs"] = ret["outputs"][:, :, partial_targets_length:]
    return ret
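The partial-target handling in this example forces already-known positions by overriding the logits. A tiny NumPy sketch of the forced_logits trick used above: the known id gets log-mass 0 and every other id gets -1e9, so decoding is pinned to the given prefix.
import numpy as np

def forced_logits_sketch(known_id, vocab_size, neg_inf=-1e9):
  # Same spirit as tf.one_hot(known_id, vocab_size, 0.0, -1e9).
  logits = np.full((vocab_size,), neg_inf, dtype=np.float32)
  logits[known_id] = 0.0
  return logits

print(np.argmax(forced_logits_sketch(known_id=7, vocab_size=10)))  # -> 7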
Example #19
    def _fast_decode(self,
                     features,
                     decode_length,
                     last_position_only=True,
                     beam_size=1,
                     top_beams=1,
                     alpha=1.0):
        """Fast decoding.

    Implements both greedy and beam search decoding, uses beam search iff
    beam_size > 1, otherwise beam search related arguments are ignored.

    Args:
      features: a map of string to model  features.
      decode_length: an integer.  How many additional timesteps to decode.
      last_position_only: MUST be true for fast decoding!
      beam_size: number of beams.
      top_beams: an integer. How many of the beams to return.
      alpha: Float that controls the length penalty. larger the alpha, stronger
        the preference for longer translations.

    Returns:
       samples: an integer `Tensor`. Top samples from the beam search

    Raises:
      ValueError: If last_position_only is False
      NotImplementedError: If there are multiple data shards.
    """
        if not last_position_only:
            raise ValueError(
                "Fast decoding only deals with the last positions!")
        if self._num_datashards != 1:
            raise NotImplementedError(
                "Fast decoding only supports a single shard.")
        dp = self._data_parallelism
        hparams = self._hparams

        inputs = features["inputs"]
        batch_size = tf.shape(inputs)[0]
        target_modality = self._problem_hparams.target_modality
        if t2t_model.is_class_modality(target_modality):
            decode_length = 1
        else:
            decode_length = tf.shape(inputs)[1] + decode_length

        # TODO(llion): Clean up this reshaping logic.
        inputs = tf.expand_dims(inputs, axis=1)
        if len(inputs.shape) < 5:
            inputs = tf.expand_dims(inputs, axis=4)
        s = tf.shape(inputs)
        inputs = tf.reshape(inputs, [s[0] * s[1], s[2], s[3], s[4]])
        # _shard_features called to ensure that the variable names match
        inputs = self._shard_features({"inputs": inputs})["inputs"]
        input_modality = self._problem_hparams.input_modality["inputs"]
        with tf.variable_scope(input_modality.name):
            inputs = input_modality.bottom_sharded(inputs, dp)
        with tf.variable_scope("body"):
            encoder_output, encoder_decoder_attention_bias = dp(
                self.encode, inputs, features["target_space_id"], hparams)
        encoder_output = encoder_output[0]
        encoder_decoder_attention_bias = encoder_decoder_attention_bias[0]

        if hparams.pos == "timing":
            timing_signal = common_attention.get_timing_signal_1d(
                decode_length + 1, hparams.hidden_size)

        def preprocess_targets(targets, i):
            """Performs preprocessing steps on the targets to prepare for the decoder.

      This includes:
        - Embedding the ids.
        - Flattening to 3D tensor.
        - Optionally adding timing signals.

      Args:
        targets: inputs ids to the decoder. [batch_size, 1]
        i: scalar, Step number of the decoding loop.

      Returns:
        Processed targets [batch_size, 1, hidden_dim]
      """
            # _shard_features called to ensure that the variable names match
            targets = self._shard_features({"targets": targets})["targets"]
            with tf.variable_scope(target_modality.name):
                targets = target_modality.targets_bottom_sharded(targets,
                                                                 dp)[0]
            targets = common_layers.flatten4d3d(targets)

            # TODO(llion): Explain! Is this even needed?
            targets = tf.cond(tf.equal(i, 0), lambda: tf.zeros_like(targets),
                              lambda: targets)

            if hparams.pos == "timing":
                targets += timing_signal[:, i:i + 1]
            return targets

        decoder_self_attention_bias = (
            common_attention.attention_bias_lower_triangle(decode_length))
        if hparams.proximity_bias:
            decoder_self_attention_bias += common_attention.attention_bias_proximal(
                decode_length)

        def symbols_to_logits_fn(ids, i, cache):
            """Go from ids to logits for next symbol."""
            ids = ids[:, -1:]
            targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
            targets = preprocess_targets(targets, i)

            bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]

            with tf.variable_scope("body"):
                body_outputs = dp(self.decode, targets,
                                  cache["encoder_output"],
                                  cache["encoder_decoder_attention_bias"],
                                  bias, hparams, cache)

            with tf.variable_scope(target_modality.name):
                logits = target_modality.top_sharded(body_outputs, None, dp)[0]

            return tf.squeeze(logits, axis=[1, 2, 3]), cache

        key_channels = hparams.attention_key_channels or hparams.hidden_size
        value_channels = hparams.attention_value_channels or hparams.hidden_size
        num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers

        cache = {
            "layer_%d" % layer: {
                "k": tf.zeros([batch_size, 0, key_channels]),
                "v": tf.zeros([batch_size, 0, value_channels]),
            }
            for layer in range(num_layers)
        }

        # Set 2nd dim to None since it's not invariant in the tf.while_loop
        # Note: Tensor.set_shape() does not work here since it merges shape info.
        # TODO(llion): Find a more robust solution.
        # pylint: disable=protected-access
        for layer in cache:
            cache[layer]["k"]._shape = tf.TensorShape(
                [None, None, key_channels])
            cache[layer]["v"]._shape = tf.TensorShape(
                [None, None, value_channels])
        # pylint: enable=protected-access
        cache["encoder_output"] = encoder_output
        cache[
            "encoder_decoder_attention_bias"] = encoder_decoder_attention_bias

        if beam_size > 1:  # Beam Search
            target_modality = (
                self._hparams.problems[self._problem_idx].target_modality)
            vocab_size = target_modality.top_dimensionality
            initial_ids = tf.zeros([batch_size], dtype=tf.int32)
            decoded_ids, _ = beam_search.beam_search(symbols_to_logits_fn,
                                                     initial_ids,
                                                     beam_size,
                                                     decode_length,
                                                     vocab_size,
                                                     alpha,
                                                     states=cache)

            if top_beams == 1:
                decoded_ids = decoded_ids[:, 0, 1:]
            else:
                decoded_ids = decoded_ids[:, :top_beams, 1:]
        else:  # Greedy

            def inner_loop(i, next_id, decoded_ids, cache):
                logits, cache = symbols_to_logits_fn(next_id, i, cache)
                next_id = tf.expand_dims(tf.argmax(logits, axis=-1), axis=1)
                decoded_ids = tf.concat([decoded_ids, next_id], axis=1)
                return i + 1, next_id, decoded_ids, cache

            decoded_ids = tf.zeros([batch_size, 0], dtype=tf.int64)
            next_id = tf.zeros([batch_size, 1], dtype=tf.int64)
            _, _, decoded_ids, _ = tf.while_loop(
                # TODO(llion): Early stopping.
                lambda i, *_: tf.less(i, decode_length),
                inner_loop,
                [tf.constant(0), next_id, decoded_ids, cache],
                shape_invariants=[
                    tf.TensorShape([]),
                    tf.TensorShape([None, None]),
                    tf.TensorShape([None, None]),
                    nest.map_structure(lambda t: tf.TensorShape(t.shape),
                                       cache),
                ])

        return decoded_ids
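A schematic numpy sketch (not the tensor2tensor implementation) of why the cache above is built with a zero-length second dimension and why its shape has to be relaxed for the tf.while_loop: each decoding step appends one key/value row per layer, so that dimension grows from 0 to decode_length.

import numpy as np

key_channels, value_channels, batch_size = 4, 4, 2
cache = {
    "k": np.zeros((batch_size, 0, key_channels), dtype=np.float32),
    "v": np.zeros((batch_size, 0, value_channels), dtype=np.float32),
}

def decode_step(cache, new_k, new_v):
    # Mirrors, conceptually, what incremental attention does with the cache:
    # concatenate this step's key/value along the growing time axis.
    cache["k"] = np.concatenate([cache["k"], new_k], axis=1)
    cache["v"] = np.concatenate([cache["v"], new_v], axis=1)
    return cache

for step in range(3):
    new_k = np.random.randn(batch_size, 1, key_channels).astype(np.float32)
    new_v = np.random.randn(batch_size, 1, value_channels).astype(np.float32)
    cache = decode_step(cache, new_k, new_v)

assert cache["k"].shape == (batch_size, 3, key_channels)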
Example #20
0
    def _fast_decode(self,
                     features,
                     decode_length,
                     beam_size=1,
                     top_beams=1,
                     alpha=1.0):
        """Fast decoding.
        
        Overrides tensor2tensor.models.transformer.Transformer._fast_decode
        to let symbols_to_logits_fn return multiple things.
    
        Implements both greedy and beam search decoding, uses beam search iff
        beam_size > 1, otherwise beam search related arguments are ignored.
    
        Args:
          features: a map of string to model features.
          decode_length: an integer.  How many additional timesteps to decode.
          beam_size: number of beams.
          top_beams: an integer. How many of the beams to return.
          alpha: Float that controls the length penalty. The larger the alpha, the
            stronger the preference for longer translations.
    
        Returns:
          A dict of decoding results {
              "body_output": tensor of size
                  [batch_size, <= decode_length, hidden_size]
                  (or [batch_size, top_beams, <= decode_length, hidden_size])
                  giving the raw output of the Transformer decoder corresponding
                  to the predicted sequences
              "outputs": integer `Tensor` of decoded ids of shape
                  [batch_size, <= decode_length] if beam_size == 1 or
                  [batch_size, top_beams, <= decode_length]
              "scores": decoding log probs from the beam search,
                  None if using greedy decoding (beam_size=1)
          }
    
        Raises:
          NotImplementedError: If there are multiple data shards.
        """
        if self._num_datashards != 1:
            raise NotImplementedError(
                "Fast decoding only supports a single shard.")
        dp = self._data_parallelism
        hparams = self._hparams
        target_modality = self._problem_hparams.target_modality
        if isinstance(target_modality, dict):
            primary_target_feature = self._problem_hparams.primary_target_modality
            primary_target_modality = target_modality[primary_target_feature]
            bottom_variable_scope = "%s/%s" % (primary_target_modality.name,
                                               primary_target_feature)
        else:
            primary_target_feature = "targets"
            primary_target_modality = target_modality
            bottom_variable_scope = target_modality.name

        if self.has_input:
            inputs = features["inputs"]
            if primary_target_modality.is_class_modality:
                decode_length = 1
            else:
                decode_length = (common_layers.shape_list(inputs)[1] +
                                 features.get("decode_length", decode_length))

            # TODO(llion): Clean up this reshaping logic.
            inputs = tf.expand_dims(inputs, axis=1)
            if len(inputs.shape) < 5:
                inputs = tf.expand_dims(inputs, axis=4)
            s = common_layers.shape_list(inputs)
            batch_size = s[0]
            inputs = tf.reshape(inputs, [s[0] * s[1], s[2], s[3], s[4]])
            # _shard_features called to ensure that the variable names match
            inputs = self._shard_features({"inputs": inputs})["inputs"]
            input_modality = self._problem_hparams.input_modality["inputs"]
            with tf.variable_scope(input_modality.name):
                inputs = input_modality.bottom_sharded(inputs, dp)
            with tf.variable_scope("body"):
                encoder_output, encoder_decoder_attention_bias = dp(
                    self.encode,
                    inputs,
                    features["target_space_id"],
                    hparams,
                    features=features)
            encoder_output = encoder_output[0]
            encoder_decoder_attention_bias = encoder_decoder_attention_bias[0]
            partial_targets = None
        else:
            # The problem has no inputs.
            encoder_output = None
            encoder_decoder_attention_bias = None

            # Prepare partial targets.
            # In either features["inputs"] or features["targets"].
            # We force the outputs to begin with these sequences.
            partial_targets = features.get("inputs")
            if partial_targets is None:
                partial_targets = features[primary_target_feature]
            assert partial_targets is not None
            partial_targets = common_layers.expand_squeeze_to_nd(
                partial_targets, 2)
            partial_targets = tf.to_int64(partial_targets)
            partial_targets_shape = common_layers.shape_list(partial_targets)
            partial_targets_length = partial_targets_shape[1]
            decode_length = (partial_targets_length +
                             features.get("decode_length", decode_length))
            batch_size = partial_targets_shape[0]

        if hparams.pos == "timing":
            positional_encoding = common_attention.get_timing_signal_1d(
                decode_length + 1, hparams.hidden_size)
        elif hparams.pos == "emb":
            positional_encoding = common_attention.add_positional_embedding(
                tf.zeros([1, decode_length + 1, hparams.hidden_size]),
                hparams.max_length, "targets_positional_embedding", None)
        else:
            positional_encoding = None

        def preprocess_targets(targets, i):
            """Performs preprocessing steps on the targets to prepare for the decoder.
    
              This includes:
                - Embedding the ids.
                - Flattening to 3D tensor.
                - Optionally adding timing signals.
    
              Args:
                targets: inputs ids to the decoder. [batch_size, 1]
                i: scalar, Step number of the decoding loop.
    
              Returns:
                Processed targets [batch_size, 1, hidden_dim]
            """
            # _shard_features called to ensure that the variable names match
            targets = self._shard_features({primary_target_feature:
                                            targets})[primary_target_feature]
            with tf.variable_scope(bottom_variable_scope):
                targets = primary_target_modality.targets_bottom_sharded(
                    targets, dp)[0]
            targets = common_layers.flatten4d3d(targets)

            # At step 0, targets will have 0 size, and instead we want to
            # create an embedding of all-zero, corresponding to the start symbol
            # this matches what transformer_prepare_decoder does to the target
            # outputs during training
            targets = tf.cond(tf.equal(i, 0), lambda: tf.zeros_like(targets),
                              lambda: targets)

            if positional_encoding is not None:
                targets += positional_encoding[:, i:i + 1]
            return targets

        decoder_self_attention_bias = (
            common_attention.attention_bias_lower_triangle(decode_length))
        if hparams.proximity_bias:
            decoder_self_attention_bias += common_attention.attention_bias_proximal(
                decode_length)

        def symbols_to_logits_fn(ids, i, cache):
            """Go from ids to logits for next symbol."""
            ids = ids[:, -1:]
            targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
            targets = preprocess_targets(targets, i)

            bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]

            logits = self._symbols_to_logits_fn(targets, features, bias, cache)

            logits = tf.squeeze(logits, axis=[1, 2, 3])
            if partial_targets is not None:
                # If the position is within the given partial targets, we alter the
                # logits to always return those values.
                # A faster approach would be to process the partial targets in one
                # iteration in order to fill the corresponding parts of the cache.
                # This would require broader changes, though.
                vocab_size = tf.shape(logits)[1]

                def forced_logits():
                    return tf.one_hot(
                        tf.tile(partial_targets[:, i], [beam_size]),
                        vocab_size, 0.0, -1e9)

                logits = tf.cond(tf.less(i, partial_targets_length),
                                 forced_logits, lambda: logits)
            return logits, cache

        cache = dict()
        infer_out = dict()
        if encoder_output is not None:
            padding_mask = 1. - common_attention.attention_bias_to_padding(
                encoder_decoder_attention_bias)
            masked_encoded_output = encoder_output * tf.expand_dims(
                padding_mask, axis=2)

            infer_out["encoded_inputs"] = tf.reduce_sum(masked_encoded_output,
                                                        axis=1)

        self._prepare_decoder_cache(batch_size, features, cache)

        ret = fast_decode(
            encoder_output=encoder_output,
            encoder_decoder_attention_bias=encoder_decoder_attention_bias,
            symbols_to_logits_fn=symbols_to_logits_fn,
            hparams=hparams,
            decode_length=decode_length,
            vocab_size=primary_target_modality.top_dimensionality,
            beam_size=beam_size,
            top_beams=top_beams,
            alpha=alpha,
            batch_size=batch_size,
            force_decode_length=self._decode_hparams.force_decode_length,
            cache=cache)
        infer_out.update(ret)
        if "cache" in ret:
            infer_out.update(ret["cache"])

        if partial_targets is not None:
            if beam_size <= 1 or top_beams <= 1:
                infer_out["outputs"] = infer_out[
                    "outputs"][:, partial_targets_length:]
            else:
                infer_out["outputs"] = infer_out[
                    "outputs"][:, :, partial_targets_length:]

        return infer_out
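A hedged numpy sketch (hypothetical ids and vocabulary size) of the forced_logits trick used in symbols_to_logits_fn above: inside the partial-target prefix, the model's logits are replaced by a vector with 0 at the forced id and -1e9 everywhere else, so both argmax and beam search must emit the forced token while its log-probability contribution stays near 0.

import numpy as np

vocab_size = 10
model_logits = np.random.randn(2, vocab_size).astype(np.float32)  # [batch, vocab]
forced_ids = np.array([4, 7])  # tokens we want to force at this step

def forced_logits(ids, vocab):
    # 0.0 for the forced id, -1e9 for everything else
    # (mirrors tf.one_hot(..., 0.0, -1e9) in spirit).
    out = np.full((len(ids), vocab), -1e9, dtype=np.float32)
    out[np.arange(len(ids)), ids] = 0.0
    return out

step_logits = forced_logits(forced_ids, vocab_size)
assert (step_logits.argmax(axis=-1) == forced_ids).all()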
def transformer_prepare_encoder(inputs,
                                target_space,
                                hparams,
                                features=None,
                                type_ids=None,
                                num_types=None,
                                reuse_target_embedding=tf.AUTO_REUSE):
    """Prepare one shard of the model for the encoder.

  Args:
    inputs: a Tensor.
    target_space: a Tensor.
    hparams: run hyperparameters
    features: optionally pass the entire features dictionary as well.
      This is needed now for "packed" datasets.
    type_ids: optional, an int64 Tensor of shape [batch, length] that allows
      for adding type embeddings, similar to positional embeddings.
    num_types: optional, an int that decides the number of types in type_ids.
    reuse_target_embedding: option to reuse variable name in the case that
      symbol modalities are reused between inputs/targets.

  Returns:
    encoder_input: a Tensor, bottom of encoder stack
    encoder_self_attention_bias: a bias tensor for use in encoder self-attention
    encoder_decoder_attention_bias: a bias tensor for use in encoder-decoder
      attention
  """
    ishape_static = inputs.shape.as_list()
    encoder_input = inputs
    if features and "inputs_segmentation" in features:
        # Packed dataset.  Keep the examples from seeing each other.
        inputs_segmentation = features["inputs_segmentation"]
        inputs_position = features["inputs_position"]
        targets_segmentation = features["targets_segmentation"]
        if (hasattr(hparams, "unidirectional_encoder")
                and hparams.unidirectional_encoder):
            tf.logging.info("Using unidirectional encoder")
            encoder_self_attention_bias = (
                common_attention.attention_bias_lower_triangle(
                    common_layers.shape_list(inputs)[1]))
        else:
            encoder_self_attention_bias = (
                common_attention.attention_bias_same_segment(
                    inputs_segmentation, inputs_segmentation))
        encoder_decoder_attention_bias = (
            common_attention.attention_bias_same_segment(
                targets_segmentation, inputs_segmentation))
    else:
        encoder_padding = common_attention.embedding_to_padding(encoder_input)
        ignore_padding = common_attention.attention_bias_ignore_padding(
            encoder_padding)
        if (hasattr(hparams, "unidirectional_encoder")
                and hparams.unidirectional_encoder):
            tf.logging.info("Using unidirectional encoder")
            encoder_self_attention_bias = (
                common_attention.attention_bias_lower_triangle(
                    common_layers.shape_list(inputs)[1]))
        else:
            # Usual case - not a packed dataset.
            encoder_self_attention_bias = ignore_padding
        encoder_decoder_attention_bias = ignore_padding
        inputs_position = None
    if hparams.proximity_bias:
        encoder_self_attention_bias += common_attention.attention_bias_proximal(
            common_layers.shape_list(inputs)[1])
    if target_space is not None and hparams.get("use_target_space_embedding",
                                                True):
        # Append target_space_id embedding to inputs.
        emb_target_space = common_layers.embedding(
            target_space,
            32,
            ishape_static[-1],
            name="target_space_embedding",
            dtype=hparams.get("activation_dtype", "float32"),
            reuse=reuse_target_embedding)
        emb_target_space = tf.reshape(emb_target_space, [1, 1, -1])
        encoder_input += emb_target_space
    if hparams.pos == "timing":
        if inputs_position is not None:
            encoder_input = common_attention.add_timing_signal_1d_given_position(
                encoder_input, inputs_position)
        else:
            encoder_input = common_attention.add_timing_signal_1d(
                encoder_input)
    elif hparams.pos == "timing_from_features":
        encoder_input = common_attention.add_timing_signals_from_features(
            encoder_input, features, hparams.position_features)
    elif hparams.pos == "emb":
        encoder_input = common_attention.add_positional_embedding(
            encoder_input, hparams.max_length, "inputs_positional_embedding",
            inputs_position)

    # Add type embeddings
    if type_ids is not None:
        if not num_types:
            raise ValueError("Need to set num_types as well.")
        encoder_input = common_attention.add_positional_embedding(
            encoder_input, num_types, "inputs_type_embedding", type_ids)

    encoder_self_attention_bias = common_layers.cast_like(
        encoder_self_attention_bias, encoder_input)
    encoder_decoder_attention_bias = common_layers.cast_like(
        encoder_decoder_attention_bias, encoder_input)
    return (encoder_input, encoder_self_attention_bias,
            encoder_decoder_attention_bias)
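A numpy sketch of the non-packed padding branch above (assuming the usual [batch, 1, 1, length] bias layout and the -1e9 mask value): a padding indicator becomes an additive attention bias that removes attention mass from padded positions.

import numpy as np

# 1.0 marks a padded position, 0.0 a real token
# (the kind of indicator embedding_to_padding returns).
padding = np.array([[0., 0., 0., 1., 1.],
                    [0., 0., 1., 1., 1.]], dtype=np.float32)  # [batch, length]

# Broadcastable bias: -1e9 on padded keys, 0 on real keys.
ignore_padding_bias = (padding * -1e9).reshape(padding.shape[0], 1, 1, -1)

logits = np.zeros((2, 1, 5, 5), dtype=np.float32)  # [batch, heads, q_len, k_len]
masked = logits + ignore_padding_bias               # padded keys get huge negative scores

weights = np.exp(masked) / np.exp(masked).sum(axis=-1, keepdims=True)
assert np.allclose(weights[0, 0, 0, 3:], 0.0)  # no attention mass on padding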
Example #22
0
def transformer_prepare_encoder(inputs, target_space, hparams, features=None):
    """Prepare one shard of the model for the encoder.

  Args:
    inputs: a Tensor.
      sg: inputs here have been flattened to 3d
        [batch, height, width, embed_size] ->
        [batch, height*width, embed_size]
    target_space: a Tensor.
    hparams: run hyperparameters
    features: optionally pass the entire features dictionary as well.
      This is needed now for "packed" datasets.

  Returns:
    encoder_input: a Tensor, bottom of encoder stack
    encoder_self_attention_bias: a bias tensor for use in encoder self-attention
    encoder_decoder_attention_bias: a bias tensor for use in encoder-decoder
      attention
  """
    ishape_static = inputs.shape.as_list()
    encoder_input = inputs
    if features and "inputs_segmentation" in features:
        # Packed dataset.  Keep the examples from seeing each other.
        inputs_segmentation = features["inputs_segmentation"]
        inputs_position = features["inputs_position"]
        targets_segmentation = features["targets_segmentation"]
        encoder_self_attention_bias = common_attention.attention_bias_same_segment(
            inputs_segmentation, inputs_segmentation)
        encoder_decoder_attention_bias = (
            common_attention.attention_bias_same_segment(
                targets_segmentation, inputs_segmentation))
    else:
        # Usual case - not a packed dataset.
        encoder_padding = common_attention.embedding_to_padding(encoder_input)
        # sg: [batch_size, sentence_len]
        ignore_padding = common_attention.attention_bias_ignore_padding(
            encoder_padding)
        # sg: [batch_size, 1, 1, sentence_len]
        # a bias tensor to be added to the attention logits:
        # padded positions get a bias of -1e9,
        # non-padded positions get a bias of 0
        encoder_self_attention_bias = ignore_padding
        encoder_decoder_attention_bias = ignore_padding
        inputs_position = None
    if hparams.proximity_bias:
        encoder_self_attention_bias += common_attention.attention_bias_proximal(
            common_layers.shape_list(inputs)[1])
    # Append target_space_id embedding to inputs.
    emb_target_space = common_layers.embedding(
        target_space,
        32,
        # sg: 32 is the vocab size used for the target-space embedding
        # (approximate; at present t2t only defines SpaceIDs 1 to 32
        # in problem.py)
        ishape_static[-1],
        # sg: embedding dimension
        name="target_space_embedding",
        dtype=tf.bfloat16
        if hparams.activation_dtype == "bfloat16" else tf.float32)
    # sg: [1,128] a dense vector to represent SpaceID
    emb_target_space = tf.reshape(emb_target_space, [1, 1, -1])
    # sg: [1,1,128]
    encoder_input += emb_target_space
    if hparams.pos == "timing":
        if inputs_position is not None:
            encoder_input = common_attention.add_timing_signal_1d_given_position(
                encoder_input, inputs_position)
        else:
            encoder_input = common_attention.add_timing_signal_1d(
                encoder_input)
    if hparams.activation_dtype == "bfloat16":
        encoder_self_attention_bias = tf.cast(encoder_self_attention_bias,
                                              tf.bfloat16)
        encoder_decoder_attention_bias = tf.cast(
            encoder_decoder_attention_bias, tf.bfloat16)
    return (encoder_input, encoder_self_attention_bias,
            encoder_decoder_attention_bias)
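A tiny numpy sketch of the target-space embedding step above (illustrative sizes: hidden size 8 instead of the 128 mentioned in the sg comments): a single [1, 1, hidden] vector is broadcast-added to every position of the [batch, length, hidden] encoder input.

import numpy as np

batch, length, hidden = 2, 5, 8
encoder_input = np.random.randn(batch, length, hidden).astype(np.float32)

# One learned row per SpaceID; here we just look up id 3 from a random table.
target_space_embedding = np.random.randn(32, hidden).astype(np.float32)
emb_target_space = target_space_embedding[3].reshape(1, 1, hidden)

encoder_input = encoder_input + emb_target_space  # broadcast over batch and length
assert encoder_input.shape == (batch, length, hidden)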
Example #23
0
  def _fast_decode(self,
                   features,
                   decode_length,
                   beam_size=1,
                   top_beams=1,
                   alpha=1.0):
    """Fast decoding.

    Implements both greedy and beam search decoding, uses beam search iff
    beam_size > 1, otherwise beam search related arguments are ignored.

    Args:
      features: a map of string to model features.
      decode_length: an integer.  How many additional timesteps to decode.
      beam_size: number of beams.
      top_beams: an integer. How many of the beams to return.
      alpha: Float that controls the length penalty. The larger the alpha, the
        stronger the preference for longer translations.

    Returns:
      A dict of decoding results {
          "outputs": integer `Tensor` of decoded ids of shape
              [batch_size, <= decode_length] if beam_size == 1 or
              [batch_size, top_beams, <= decode_length]
          "scores": decoding log probs from the beam search,
              None if using greedy decoding (beam_size=1)
      }

    Raises:
      NotImplementedError: If there are multiple data shards.
    """
    if self._num_datashards != 1:
      raise NotImplementedError("Fast decoding only supports a single shard.")
    dp = self._data_parallelism
    hparams = self._hparams
    target_modality = self._problem_hparams.target_modality

    if self.has_input:
      inputs = features["inputs"]
      if target_modality.is_class_modality:
        decode_length = 1
      else:
        decode_length = common_layers.shape_list(inputs)[1] + decode_length

      # TODO(llion): Clean up this reshaping logic.
      inputs = tf.expand_dims(inputs, axis=1)
      if len(inputs.shape) < 5:
        inputs = tf.expand_dims(inputs, axis=4)
      s = common_layers.shape_list(inputs)
      batch_size = s[0]
      inputs = tf.reshape(inputs, [s[0] * s[1], s[2], s[3], s[4]])
      # _shard_features called to ensure that the variable names match
      inputs = self._shard_features({"inputs": inputs})["inputs"]
      input_modality = self._problem_hparams.input_modality["inputs"]
      with tf.variable_scope(input_modality.name):
        inputs = input_modality.bottom_sharded(inputs, dp)
      with tf.variable_scope("body"):
        encoder_output, encoder_decoder_attention_bias = dp(
            self.encode, inputs, features["target_space_id"], hparams,
            features=features)
      encoder_output = encoder_output[0]
      encoder_decoder_attention_bias = encoder_decoder_attention_bias[0]
      partial_targets = None
    else:
      # The problem has no inputs.
      # In this case, features["inputs"] contains partial targets.
      # We force the outputs to begin with these sequences.
      encoder_output = None
      encoder_decoder_attention_bias = None
      partial_targets = tf.squeeze(tf.to_int64(features["inputs"]), [2, 3])
      partial_targets_length = common_layers.shape_list(partial_targets)[1]
      decode_length += partial_targets_length
      batch_size = tf.shape(partial_targets)[0]

    if hparams.pos == "timing":
      timing_signal = common_attention.get_timing_signal_1d(
          decode_length + 1, hparams.hidden_size)

    def preprocess_targets(targets, i):
      """Performs preprocessing steps on the targets to prepare for the decoder.

      This includes:
        - Embedding the ids.
        - Flattening to 3D tensor.
        - Optionally adding timing signals.

      Args:
        targets: inputs ids to the decoder. [batch_size, 1]
        i: scalar, Step number of the decoding loop.

      Returns:
        Processed targets [batch_size, 1, hidden_dim]
      """
      # _shard_features called to ensure that the variable names match
      targets = self._shard_features({"targets": targets})["targets"]
      with tf.variable_scope(target_modality.name):
        targets = target_modality.targets_bottom_sharded(targets, dp)[0]
      targets = common_layers.flatten4d3d(targets)

      # TODO(llion): Explain! Is this even needed?
      targets = tf.cond(
          tf.equal(i, 0), lambda: tf.zeros_like(targets), lambda: targets)

      if hparams.pos == "timing":
        targets += timing_signal[:, i:i + 1]
      return targets

    decoder_self_attention_bias = (
        common_attention.attention_bias_lower_triangle(decode_length))
    if hparams.proximity_bias:
      decoder_self_attention_bias += common_attention.attention_bias_proximal(
          decode_length)

    def symbols_to_logits_fn(ids, i, cache):
      """Go from ids to logits for next symbol."""
      ids = ids[:, -1:]
      targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
      targets = preprocess_targets(targets, i)

      bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]

      with tf.variable_scope("body"):
        body_outputs = dp(
            self.decode, targets, cache.get("encoder_output"),
            cache.get("encoder_decoder_attention_bias"),
            bias, hparams, cache,
            nonpadding=features_to_nonpadding(features, "targets"))

      with tf.variable_scope(target_modality.name):
        logits = target_modality.top_sharded(body_outputs, None, dp)[0]

      ret = tf.squeeze(logits, axis=[1, 2, 3])
      if partial_targets is not None:
        # If the position is within the given partial targets, we alter the
        # logits to always return those values.
        # A faster approach would be to process the partial targets in one
        # iteration in order to fill the corresponding parts of the cache.
        # This would require broader changes, though.
        vocab_size = tf.shape(ret)[1]
        def forced_logits():
          return tf.one_hot(tf.tile(partial_targets[:, i], [beam_size]),
                            vocab_size, 0.0, -1e9)
        ret = tf.cond(
            tf.less(i, partial_targets_length), forced_logits, lambda: ret)
      return ret, cache

    ret = fast_decode(
        encoder_output=encoder_output,
        encoder_decoder_attention_bias=encoder_decoder_attention_bias,
        symbols_to_logits_fn=symbols_to_logits_fn,
        hparams=hparams,
        decode_length=decode_length,
        vocab_size=target_modality.top_dimensionality,
        beam_size=beam_size,
        top_beams=top_beams,
        alpha=alpha,
        batch_size=batch_size)
    if partial_targets is not None:
      ret["outputs"] = ret["outputs"][:, partial_targets_length:]
    return ret
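For reference, a small numpy sketch of the lower-triangle bias used above (assuming the [1, 1, length, length] layout and -1e9 masking value of common_attention) and of the per-step slice [:, :, i:i+1, :i+1] taken in symbols_to_logits_fn, where only the single query at position i and its i+1 visible keys exist.

import numpy as np

def lower_triangle_bias(length):
    # 0 where a query may attend (keys at or before its position), -1e9 elsewhere.
    mask = np.tril(np.ones((length, length), dtype=np.float32))
    return np.where(mask > 0, 0.0, -1e9).reshape(1, 1, length, length)

decode_length = 5
bias = lower_triangle_bias(decode_length)

i = 2  # current decoding step
# During incremental decoding only one query (position i) and i+1 keys exist,
# so the full bias is sliced down to shape [1, 1, 1, i + 1].
step_bias = bias[:, :, i:i + 1, :i + 1]
assert step_bias.shape == (1, 1, 1, i + 1)
assert np.all(step_bias == 0.0)  # position i may attend to all of 0..i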
Example #24
0
  def _fast_decode(self, features, decode_length, beam_size=1, top_beams=1,
                   alpha=1.0):
    """Fast decoding.

    Implements both greedy and beam search decoding, uses beam search iff
    beam_size > 1, otherwise beam search related arguments are ignored.

    Args:
      features: a map of string to model features.
      decode_length: an integer.  How many additional timesteps to decode.
      beam_size: number of beams.
      top_beams: an integer. How many of the beams to return.
      alpha: Float that controls the length penalty. The larger the alpha, the
        stronger the preference for longer translations.

    Returns:
      A dict of decoding results {
          "outputs": integer `Tensor` of decoded ids of shape
              [batch_size, <= decode_length] if beam_size == 1 or
              [batch_size, top_beams, <= decode_length]
          "scores": decoding log probs from the beam search,
              None if using greedy decoding (beam_size=1)
      }

    Raises:
      NotImplementedError: If there are multiple data shards.
    """
    if self._num_datashards != 1:
      raise NotImplementedError("Fast decoding only supports a single shard.")
    dp = self._data_parallelism
    hparams = self._hparams
    target_modality = self._problem_hparams.target_modality

    story = features[babi_qa.FeatureNames.STORY]
    question = features[babi_qa.FeatureNames.QUESTION]

    if target_modality.is_class_modality:
      decode_length = 1
    else:
      decode_length = (common_layers.shape_list(story)[1] +
                       common_layers.shape_list(question)[1] + decode_length)

    story = tf.expand_dims(story, axis=1)
    question = tf.expand_dims(question, axis=1)

    if len(story.shape) < 5:
      story = tf.expand_dims(story, axis=4)

    if len(question.shape) < 5:
      question = tf.expand_dims(question, axis=4)

    s = common_layers.shape_list(story)
    batch_size = s[0]
    story = tf.reshape(story, [s[0] * s[1], s[2], s[3], s[4]])

    s = common_layers.shape_list(question)
    batch_size = s[0]

    question = tf.reshape(question, [s[0] * s[1], s[2], s[3], s[4]])

    # _shard_features called to ensure that the variable names match
    story = self._shard_features({babi_qa.FeatureNames.STORY: story}
                                 )[babi_qa.FeatureNames.STORY]

    question = self._shard_features({babi_qa.FeatureNames.QUESTION: question}
                                    )[babi_qa.FeatureNames.QUESTION]

    story_modality = self._problem_hparams.input_modality[
                  babi_qa.FeatureNames.STORY]
    question_modality = self._problem_hparams.input_modality[
                  babi_qa.FeatureNames.QUESTION]


    with tf.variable_scope(story_modality.name):
      story = story_modality.bottom_sharded(story, dp)


    with tf.variable_scope(question_modality.name,
                    reuse=(story_modality.name == question_modality.name)):
      question = question_modality.bottom_sharded(question, dp)

    with tf.variable_scope("body"):
      if target_modality.is_class_modality:
        encoder_output = dp(self.encode, story, question,
                              features["target_space_id"], hparams)
      else:
        encoder_output, encoder_decoder_attention_bias = dp(self.encode, story,
          question, features["target_space_id"],hparams,features=features)
        encoder_decoder_attention_bias = encoder_decoder_attention_bias[0]

      encoder_output = encoder_output[0]


    if hparams.pos == "timing":
      timing_signal = common_attention.get_timing_signal_1d(decode_length + 1,
        hparams.hidden_size)

    def preprocess_targets(targets, i):
      """Performs preprocessing steps on the targets to prepare for the
      decoder.

      This includes:
        - Embedding the ids.
        - Flattening to 3D tensor.
        - Optionally adding timing signals.

      Args:
        targets: inputs ids to the decoder. [batch_size, 1]
        i: scalar, Step number of the decoding loop.

      Returns:
        Processed targets [batch_size, 1, hidden_dim]
      """
      # _shard_features called to ensure that the variable names match
      targets = self._shard_features({"targets": targets})["targets"]
      with tf.variable_scope(target_modality.name):
        targets = target_modality.targets_bottom_sharded(targets, dp)[0]
      targets = common_layers.flatten4d3d(targets)

      targets = tf.cond(tf.equal(i, 0), lambda: tf.zeros_like(targets),
        lambda: targets)

      if hparams.pos == "timing":
        targets += timing_signal[:, i:i + 1]
      return targets

    decoder_self_attention_bias = (
      common_attention.attention_bias_lower_triangle(decode_length))
    if hparams.proximity_bias:
      decoder_self_attention_bias += common_attention.attention_bias_proximal(
        decode_length)


    def symbols_to_logits_fn(ids, i, cache):
      """Go from ids to logits for next symbol."""
      ids = ids[:, -1:]
      targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
      targets = preprocess_targets(targets, i)

      bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]

      with tf.variable_scope("body"):
        body_outputs = dp(self.decode, targets, cache.get("encoder_output"),
          cache.get("encoder_decoder_attention_bias"), bias, hparams, cache,
          nonpadding=features_to_nonpadding(features, "targets")
                          )

      with tf.variable_scope(target_modality.name):
        logits = target_modality.top_sharded(body_outputs, None, dp)[0]

      ret = tf.squeeze(logits, axis=[1, 2, 3])
      return ret, cache

    def labels_to_logits_fn(unused_ids, unused_i, cache):
      """Go from labels to logits"""
      with tf.variable_scope("body"):
        body_outputs = dp(tf.expand_dims, cache.get("encoder_output"), 2)

      with tf.variable_scope(target_modality.name):
        logits = target_modality.top_sharded(body_outputs, None, dp)[0]

      ret = tf.squeeze(logits, axis=[1, 2, 3])
      return ret, cache

    if target_modality.is_class_modality:
      ret = transformer.fast_decode(encoder_output=encoder_output,
        encoder_decoder_attention_bias=None,
        symbols_to_logits_fn=labels_to_logits_fn, hparams=hparams,
        decode_length=decode_length,
        vocab_size=target_modality.top_dimensionality, beam_size=beam_size,
        top_beams=top_beams, alpha=alpha, batch_size=batch_size)

    else:
      ret = transformer.fast_decode(encoder_output=encoder_output,
        encoder_decoder_attention_bias=encoder_decoder_attention_bias,
        symbols_to_logits_fn=symbols_to_logits_fn, hparams=hparams,
        decode_length=decode_length,
        vocab_size=target_modality.top_dimensionality, beam_size=beam_size,
        top_beams=top_beams, alpha=alpha, batch_size=batch_size)

    return ret
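A hedged numpy sketch of the classification branch above (the pooling and projection here are illustrative stand-ins for target_modality.top_sharded, not the actual modality code): when the target is a class label, decoding runs for a single step and the "logits" are just a projection of the encoder output.

import numpy as np

batch, length, hidden, num_classes = 2, 7, 8, 5
encoder_output = np.random.randn(batch, length, hidden).astype(np.float32)
proj = np.random.randn(hidden, num_classes).astype(np.float32)  # stand-in for the modality "top"

def labels_to_logits(encoder_output):
    # Illustrative: pool over time, then project to class logits.
    pooled = encoder_output.mean(axis=1)  # [batch, hidden]
    return pooled @ proj                  # [batch, num_classes]

logits = labels_to_logits(encoder_output)
predicted_class = logits.argmax(axis=-1)  # decode_length == 1: one prediction per example
assert predicted_class.shape == (batch,)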
Example #25
0
  def _fast_decode(self,
                   features,
                   decode_length,
                   beam_size=1,
                   top_beams=1,
                   alpha=1.0):
    """Fast decoding.

    Implements both greedy and beam search decoding, uses beam search iff
    beam_size > 1, otherwise beam search related arguments are ignored.

    Args:
      features: a map of string to model features.
      decode_length: an integer.  How many additional timesteps to decode.
      beam_size: number of beams.
      top_beams: an integer. How many of the beams to return.
      alpha: Float that controls the length penalty. The larger the alpha, the
        stronger the preference for longer translations.

    Returns:
       samples: an integer `Tensor`. Top samples from the beam search

    Raises:
      NotImplementedError: If there are multiple data shards.
    """
    if self._num_datashards != 1:
      raise NotImplementedError("Fast decoding only supports a single shard.")
    dp = self._data_parallelism
    hparams = self._hparams

    inputs = features["inputs"]
    batch_size = common_layers.shape_list(inputs)[0]
    target_modality = self._problem_hparams.target_modality
    if target_modality.is_class_modality:
      decode_length = 1
    else:
      decode_length = common_layers.shape_list(inputs)[1] + decode_length

    # TODO(llion): Clean up this reshaping logic.
    inputs = tf.expand_dims(inputs, axis=1)
    if len(inputs.shape) < 5:
      inputs = tf.expand_dims(inputs, axis=4)
    s = common_layers.shape_list(inputs)
    inputs = tf.reshape(inputs, [s[0] * s[1], s[2], s[3], s[4]])
    # _shard_features called to ensure that the variable names match
    inputs = self._shard_features({"inputs": inputs})["inputs"]
    input_modality = self._problem_hparams.input_modality["inputs"]
    with tf.variable_scope(input_modality.name):
      inputs = input_modality.bottom_sharded(inputs, dp)
    with tf.variable_scope("body"):
      encoder_output, encoder_decoder_attention_bias = dp(
          self.encode, inputs, features["target_space_id"], hparams,
          features=features)
    encoder_output = encoder_output[0]
    encoder_decoder_attention_bias = encoder_decoder_attention_bias[0]

    if hparams.pos == "timing":
      timing_signal = common_attention.get_timing_signal_1d(
          decode_length + 1, hparams.hidden_size)

    def preprocess_targets(targets, i):
      """Performs preprocessing steps on the targets to prepare for the decoder.

      This includes:
        - Embedding the ids.
        - Flattening to 3D tensor.
        - Optionally adding timing signals.

      Args:
        targets: inputs ids to the decoder. [batch_size, 1]
        i: scalar, Step number of the decoding loop.

      Returns:
        Processed targets [batch_size, 1, hidden_dim]
      """
      # _shard_features called to ensure that the variable names match
      targets = self._shard_features({"targets": targets})["targets"]
      with tf.variable_scope(target_modality.name):
        targets = target_modality.targets_bottom_sharded(targets, dp)[0]
      targets = common_layers.flatten4d3d(targets)

      # TODO(llion): Explain! Is this even needed?
      targets = tf.cond(
          tf.equal(i, 0), lambda: tf.zeros_like(targets), lambda: targets)

      if hparams.pos == "timing":
        targets += timing_signal[:, i:i + 1]
      return targets

    decoder_self_attention_bias = (
        common_attention.attention_bias_lower_triangle(decode_length))
    if hparams.proximity_bias:
      decoder_self_attention_bias += common_attention.attention_bias_proximal(
          decode_length)

    def symbols_to_logits_fn(ids, i, cache):
      """Go from ids to logits for next symbol."""
      ids = ids[:, -1:]
      targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
      targets = preprocess_targets(targets, i)

      bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]

      with tf.variable_scope("body"):
        body_outputs = dp(
            self.decode, targets, cache["encoder_output"],
            cache["encoder_decoder_attention_bias"], bias, hparams, cache,
            nonpadding=_features_to_nonpadding(features, "targets"))

      with tf.variable_scope(target_modality.name):
        logits = target_modality.top_sharded(body_outputs, None, dp)[0]

      return tf.squeeze(logits, axis=[1, 2, 3]), cache

    key_channels = hparams.attention_key_channels or hparams.hidden_size
    value_channels = hparams.attention_value_channels or hparams.hidden_size
    num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers

    cache = {
        "layer_%d" % layer: {
            "k": tf.zeros([batch_size, 0, key_channels]),
            "v": tf.zeros([batch_size, 0, value_channels]),
        }
        for layer in range(num_layers)
    }

    # Set 2nd dim to None since it's not invariant in the tf.while_loop
    # Note: Tensor.set_shape() does not work here since it merges shape info.
    # TODO(llion): Find a more robust solution.
    # pylint: disable=protected-access
    if not context.in_eager_mode():
      for layer in cache:
        cache[layer]["k"]._shape = tf.TensorShape([None, None, key_channels])
        cache[layer]["v"]._shape = tf.TensorShape([None, None, value_channels])
    # pylint: enable=protected-access
    cache["encoder_output"] = encoder_output
    cache["encoder_decoder_attention_bias"] = encoder_decoder_attention_bias

    if beam_size > 1:  # Beam Search
      target_modality = (
          self._hparams.problems[self._problem_idx].target_modality)
      vocab_size = target_modality.top_dimensionality
      initial_ids = tf.zeros([batch_size], dtype=tf.int32)
      decoded_ids, scores = beam_search.beam_search(
          symbols_to_logits_fn,
          initial_ids,
          beam_size,
          decode_length,
          vocab_size,
          alpha,
          states=cache,
          stop_early=(top_beams == 1))

      if top_beams == 1:
        decoded_ids = decoded_ids[:, 0, 1:]
      else:
        decoded_ids = decoded_ids[:, :top_beams, 1:]
    else:  # Greedy

      def inner_loop(i, next_id, decoded_ids, cache):
        logits, cache = symbols_to_logits_fn(next_id, i, cache)
        temperature = (0.0 if hparams.sampling_method == "argmax" else
                       hparams.sampling_temp)
        next_id = tf.expand_dims(
            common_layers.sample_with_temperature(logits, temperature), axis=1)
        decoded_ids = tf.concat([decoded_ids, next_id], axis=1)
        return i + 1, next_id, decoded_ids, cache

      decoded_ids = tf.zeros([batch_size, 0], dtype=tf.int64)
      scores = None
      next_id = tf.zeros([batch_size, 1], dtype=tf.int64)
      _, _, decoded_ids, _ = tf.while_loop(
          # TODO(llion): Early stopping.
          lambda i, *_: tf.less(i, decode_length),
          inner_loop,
          [tf.constant(0), next_id, decoded_ids, cache],
          shape_invariants=[
              tf.TensorShape([]),
              tf.TensorShape([None, None]),
              tf.TensorShape([None, None]),
              nest.map_structure(lambda t: tf.TensorShape(t.shape), cache),
          ])

    return decoded_ids, scores
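A numpy sketch (a simplified analogue of common_layers.sample_with_temperature, not its actual code) of the greedy loop body above: with temperature 0 the next id is the argmax, otherwise it is drawn from softmax(logits / temperature).

import numpy as np

def sample_with_temperature(logits, temperature, rng=np.random.default_rng(0)):
    # temperature == 0.0 reduces to argmax; otherwise sample from the tempered softmax.
    if temperature == 0.0:
        return logits.argmax(axis=-1)
    scaled = logits.astype(np.float64) / temperature
    probs = np.exp(scaled - scaled.max(axis=-1, keepdims=True))
    probs /= probs.sum(axis=-1, keepdims=True)
    return np.array([rng.choice(len(p), p=p) for p in probs])

logits = np.array([[0.1, 2.0, 0.3], [1.5, 0.2, 0.2]], dtype=np.float32)
assert (sample_with_temperature(logits, 0.0) == np.array([1, 0])).all()
sampled_ids = sample_with_temperature(logits, 1.0)  # stochastic draw per example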
    def _fast_decode(self,
                     features,
                     decode_length,
                     beam_size=1,
                     top_beams=1,
                     alpha=1.0):
        if self._num_datashards != 1:
            raise NotImplementedError(
                "Fast decoding only supports a single shard.")
        dp = self._data_parallelism
        hparams = self._hparams
        target_modality = self._problem_hparams.modality["targets"]
        if "targets_segmentation" in features:
            raise NotImplementedError(
                "Decoding not supported on packed datasets "
                " If you want to decode from a dataset, use the non-packed version"
                " of the dataset when decoding.")
        if self.has_input:
            inputs = features["inputs"]
            if target_modality.is_class_modality:
                decode_length = 1
            else:
                decode_length = (common_layers.shape_list(inputs)[1] +
                                 features.get("decode_length", decode_length))

            contexts = {}
            for feature_name in features:
                if 'context' in feature_name and 'raw' not in feature_name:
                    contexts[feature_name] = features[feature_name]

            inputs = tf.expand_dims(inputs, axis=1)
            if len(inputs.shape) < 5:
                inputs = tf.expand_dims(inputs, axis=4)
            s = common_layers.shape_list(inputs)
            batch_size = s[0]
            inputs = tf.reshape(inputs, [s[0] * s[1], s[2], s[3], s[4]])
            # _shard_features called to ensure that the variable names match
            inputs = self._shard_features({"inputs": inputs})["inputs"]
            input_modality = self._problem_hparams.modality["inputs"]

            context_modality = {}
            for context_name in contexts:
                if context_name in self._problem_hparams.modality:
                    context_modality[
                        context_name] = self._problem_hparams.modality[
                            context_name]
                else:
                    context_modality[context_name] = input_modality

            with tf.variable_scope(input_modality.name, reuse=tf.AUTO_REUSE):
                inputs = input_modality.bottom_sharded(inputs, dp)

            for feature_name in contexts:
                with tf.variable_scope(context_modality[feature_name].name,
                                       reuse=tf.AUTO_REUSE):
                    contexts[feature_name] = context_modality[
                        feature_name].bottom_sharded(contexts[feature_name],
                                                     dp)

            contexts_list = [
                contexts[feature_name][0] for feature_name in contexts
            ]
            contexts = tf.concat(contexts_list, axis=1)
            inputs = [tf.concat([contexts, inputs[0]], axis=1)]

            with tf.variable_scope("body"):
                encoder_output, encoder_decoder_attention_bias = dp(
                    self.encode,
                    inputs,
                    features["target_space_id"],
                    hparams,
                    features=features)
            encoder_output = encoder_output[0]
            encoder_decoder_attention_bias = encoder_decoder_attention_bias[0]
            partial_targets = None
        else:
            # The problem has no inputs.
            encoder_output = None
            encoder_decoder_attention_bias = None

            # Prepare partial targets.
            # In either features["inputs"] or features["targets"].
            # We force the outputs to begin with these sequences.
            partial_targets = features.get("inputs")
            if partial_targets is None:
                partial_targets = features["targets"]
            assert partial_targets is not None
            partial_targets = common_layers.expand_squeeze_to_nd(
                partial_targets, 2)
            partial_targets = tf.to_int64(partial_targets)
            partial_targets_shape = common_layers.shape_list(partial_targets)
            partial_targets_length = partial_targets_shape[1]
            decode_length = (partial_targets_length +
                             features.get("decode_length", decode_length))
            batch_size = partial_targets_shape[0]

        if hparams.pos == "timing":
            positional_encoding = common_attention.get_timing_signal_1d(
                decode_length + 1, hparams.hidden_size)
        elif hparams.pos == "emb":
            positional_encoding = common_attention.add_positional_embedding(
                tf.zeros([1, decode_length, hparams.hidden_size]),
                hparams.max_length, "body/targets_positional_embedding", None)
        else:
            positional_encoding = None

        def preprocess_targets(targets, i):
            """Performs preprocessing steps on the targets to prepare for the decoder.
            This includes:
              - Embedding the ids.
              - Flattening to 3D tensor.
              - Optionally adding timing signals.
            Args:
              targets: inputs ids to the decoder. [batch_size, 1]
              i: scalar, Step number of the decoding loop.
            Returns:
              Processed targets [batch_size, 1, hidden_dim]
            """
            # _shard_features called to ensure that the variable names match
            targets = self._shard_features({"targets": targets})["targets"]
            with tf.variable_scope(target_modality.name):
                targets = target_modality.targets_bottom_sharded(targets,
                                                                 dp)[0]
            targets = common_layers.flatten4d3d(targets)

            targets = tf.cond(tf.equal(i, 0), lambda: tf.zeros_like(targets),
                              lambda: targets)

            if positional_encoding is not None:
                targets += positional_encoding[:, i:i + 1]
            return targets

        decoder_self_attention_bias = (
            common_attention.attention_bias_lower_triangle(decode_length))
        if hparams.proximity_bias:
            decoder_self_attention_bias += common_attention.attention_bias_proximal(
                decode_length)

        def symbols_to_logits_fn(ids, i, cache):
            """Go from ids to logits for next symbol."""
            ids = ids[:, -1:]
            targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
            targets = preprocess_targets(targets, i)

            bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]

            with tf.variable_scope("body"):
                body_outputs = dp(self.decode,
                                  targets,
                                  cache.get("encoder_output"),
                                  cache.get("encoder_decoder_attention_bias"),
                                  bias,
                                  hparams,
                                  cache,
                                  nonpadding=features_to_nonpadding(
                                      features, "targets"))

            with tf.variable_scope(target_modality.name):
                logits = target_modality.top_sharded(body_outputs, None, dp)[0]

            ret = tf.squeeze(logits, axis=[1, 2, 3])
            if partial_targets is not None:
                # If the position is within the given partial targets, we alter the
                # logits to always return those values.
                # A faster approach would be to process the partial targets in one
                # iteration in order to fill the corresponding parts of the cache.
                # This would require broader changes, though.
                vocab_size = tf.shape(ret)[1]

                def forced_logits():
                    return tf.one_hot(
                        tf.tile(partial_targets[:, i], [beam_size]),
                        vocab_size, 0.0, -1e9)

                ret = tf.cond(tf.less(i, partial_targets_length),
                              forced_logits, lambda: ret)
            return ret, cache

        ret = fast_decode(
            encoder_output=encoder_output,
            encoder_decoder_attention_bias=encoder_decoder_attention_bias,
            symbols_to_logits_fn=symbols_to_logits_fn,
            hparams=hparams,
            decode_length=decode_length,
            vocab_size=target_modality.top_dimensionality,
            beam_size=beam_size,
            top_beams=top_beams,
            alpha=alpha,
            batch_size=batch_size,
            force_decode_length=self._decode_hparams.force_decode_length)
        if partial_targets is not None:
            if beam_size <= 1 or top_beams <= 1:
                ret["outputs"] = ret["outputs"][:, partial_targets_length:]
            else:
                ret["outputs"] = ret["outputs"][:, :, partial_targets_length:]
        return ret
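A numpy sketch (hypothetical feature names, lengths, and a simplified 3-D layout) of the context handling above: each embedded context feature and the embedded inputs are concatenated along the time axis before the combined sequence is fed to the encoder.

import numpy as np

hidden = 8
embedded = {
    "context_0": np.random.randn(2, 4, hidden).astype(np.float32),  # [batch, len, hidden]
    "context_1": np.random.randn(2, 6, hidden).astype(np.float32),
    "inputs": np.random.randn(2, 5, hidden).astype(np.float32),
}

contexts = np.concatenate(
    [embedded[name] for name in sorted(embedded) if name.startswith("context")], axis=1)
encoder_sequence = np.concatenate([contexts, embedded["inputs"]], axis=1)

assert encoder_sequence.shape == (2, 4 + 6 + 5, hidden)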
Example #27
0
def transformer_prepare_encoder(inputs, target_space, hparams, features=None):
    """Prepare one shard of the model for the encoder.

  Args:
    inputs: a Tensor.
    target_space: a Tensor.
    hparams: run hyperparameters
    features: optionally pass the entire features dictionary as well.
      This is needed now for "packed" datasets.

  Returns:
    encoder_input: a Tensor, bottom of encoder stack
    encoder_self_attention_bias: a bias tensor for use in encoder self-attention
    encoder_decoder_attention_bias: a bias tensor for use in encoder-decoder
      attention
  """
    ishape_static = inputs.shape.as_list()
    encoder_input = inputs
    if features and "inputs_segmentation" in features:
        # Packed dataset.  Keep the examples from seeing each other.
        inputs_segmentation = features["inputs_segmentation"]
        inputs_position = features["inputs_position"]
        targets_segmentation = features["targets_segmentation"]
        encoder_self_attention_bias = common_attention.attention_bias_same_segment(
            inputs_segmentation, inputs_segmentation)
        encoder_decoder_attention_bias = (
            common_attention.attention_bias_same_segment(
                targets_segmentation, inputs_segmentation))
    else:
        # Usual case - not a packed dataset.
        encoder_padding = common_attention.embedding_to_padding(encoder_input)
        ignore_padding = common_attention.attention_bias_ignore_padding(
            encoder_padding)
        encoder_self_attention_bias = ignore_padding
        encoder_decoder_attention_bias = ignore_padding
        inputs_position = None
    if hparams.proximity_bias:
        encoder_self_attention_bias += common_attention.attention_bias_proximal(
            common_layers.shape_list(inputs)[1])
    # Append target_space_id embedding to inputs.
    emb_target_space = common_layers.embedding(target_space,
                                               32,
                                               ishape_static[-1],
                                               name="target_space_embedding")
    emb_target_space = tf.reshape(emb_target_space, [1, 1, -1])
    encoder_input += emb_target_space
    #if hparams.pos == "timing":
    #  if inputs_position is not None:
    #    encoder_input = common_attention.add_timing_signal_1d_given_position(
    #        encoder_input, inputs_position)
    #  else:
    #    encoder_input = common_attention.add_timing_signal_1d(encoder_input)
    raw_encoder_input = tf.squeeze(features['inputs_raw'], axis=[-2, -1])
    pos_signals = generate_positional_signals(raw_encoder_input, hparams)
    pos_embeddings = generate_positional_embeddings(pos_signals,
                                                    hparams.encoder_pos,
                                                    hparams)
    if "sum" in hparams.encoder_pos_integration:
        encoder_input = encoder_input + pos_embeddings
    elif "ffn" in hparams.encoder_pos_integration:
        with tf.variable_scope("encoder_pos_ffn"):
            encoder_input = tf.concat([encoder_input, pos_embeddings], axis=2)
            encoder_input = transformer_ffn_layer(encoder_input,
                                                  hparams,
                                                  conv_padding="SAME")
    return (encoder_input, encoder_self_attention_bias,
            encoder_decoder_attention_bias)
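For reference, a numpy sketch of the sinusoidal timing signal that the commented-out hparams.pos == "timing" branch above would normally add, following the standard get_timing_signal_1d construction (geometrically spaced timescales, sin in the first half of the channels, cos in the second); this simplified version assumes an even channel count.

import numpy as np

def timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
    """Sinusoidal positional signal of shape [1, length, channels] (channels must be even here)."""
    position = np.arange(length, dtype=np.float64)
    num_timescales = channels // 2
    log_increment = np.log(max_timescale / min_timescale) / max(num_timescales - 1, 1)
    inv_timescales = min_timescale * np.exp(-log_increment * np.arange(num_timescales))
    scaled_time = position[:, None] * inv_timescales[None, :]
    signal = np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1)
    return signal.reshape(1, length, channels).astype(np.float32)

signal = timing_signal_1d(length=10, channels=16)
assert signal.shape == (1, 10, 16)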
Example #28
0
    def _greedy_infer(self, features, decode_length, last_position_only=True):
        """Fast version of greedy decoding.

    Args:
      features: a map of string to `Tensor`
      decode_length: an integer.  How many additional timesteps to decode.
      last_position_only: MUST be true for fast decoding!

    Returns:
       samples: [batch_size, input_length + decode_length]
       logits: Not returned
       losses: Not returned

    Raises:
      ValueError: If last_position_only is False
      NotImplementedError: If there are multiple data shards.
    """
        if not last_position_only:
            raise ValueError(
                "Fast decoding only deals with the last positions!")
        if self._num_datashards != 1:
            raise NotImplementedError(
                "Fast decoding only supports a single shard.")
        dp = self._data_parallelism
        hparams = self._hparams

        inputs = features["inputs"]
        batch_size = tf.shape(inputs)[0]
        target_modality = self._problem_hparams.target_modality
        if t2t_model.is_class_modality(target_modality):
            decode_length = 1
        else:
            decode_length = tf.shape(inputs)[1] + decode_length

        # TODO(llion): Clean up this reshaping logic.
        inputs = tf.expand_dims(inputs, axis=1)
        if len(inputs.shape) < 5:
            inputs = tf.expand_dims(inputs, axis=4)
        s = tf.shape(inputs)
        inputs = tf.reshape(inputs, [s[0] * s[1], s[2], s[3], s[4]])
        # _shard_features called to ensure that the variable names match
        inputs = self._shard_features({"inputs": inputs})["inputs"]
        input_modality = self._problem_hparams.input_modality["inputs"]
        with tf.variable_scope(input_modality.name):
            inputs = input_modality.bottom_sharded(inputs, dp)
        with tf.variable_scope("body"):
            encoder_output, encoder_decoder_attention_bias = dp(
                self.encode, inputs, features["target_space_id"], hparams)

        if hparams.pos == "timing":
            timing_signal = common_attention.get_timing_signal_1d(
                decode_length + 1, hparams.hidden_size)

        def preprocess_targets(targets, i):
            """Performs preprocessing steps on the targets to prepare for the decoder.

      This includes:
        - Embedding the ids.
        - Flattening to 3D tensor.
        - Optionally adding timing signals.

      Args:
        targets: inputs ids to the decoder. [batch_size, 1]
        i: scalar, Step number of the decoding loop.

      Returns:
        Processed targets [batch_size, 1, hidden_dim]
      """
            # _shard_features called to ensure that the variable names match
            targets = self._shard_features({"targets": targets})["targets"]
            with tf.variable_scope(target_modality.name):
                targets = target_modality.targets_bottom_sharded(targets,
                                                                 dp)[0]
            targets = common_layers.flatten4d3d(targets)

            # TODO(llion): Explain! Is this even needed?
            targets = tf.cond(tf.equal(i, 0), lambda: tf.zeros_like(targets),
                              lambda: targets)

            if hparams.pos == "timing":
                targets += timing_signal[:, i:i + 1]
            return targets

        decoder_self_attention_bias = (
            common_attention.attention_bias_lower_triangle(decode_length))
        if hparams.proximity_bias:
            decoder_self_attention_bias += common_attention.attention_bias_proximal(
                decode_length)

        def symbols_to_logits_fn(ids, i, cache):
            """Go from ids to logits for next symbol."""
            targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
            targets = preprocess_targets(targets, i)

            bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]

            with tf.variable_scope("body"):
                body_outputs = dp(self.decode, targets, encoder_output[0],
                                  encoder_decoder_attention_bias[0], bias,
                                  hparams, cache)

            with tf.variable_scope(target_modality.name):
                logits = target_modality.top_sharded(body_outputs, None, dp)[0]

            return tf.squeeze(logits, axis=[1, 2, 3])

        def inner_loop(i, next_id, decoded_ids, cache):
            logits = symbols_to_logits_fn(next_id, i, cache)
            next_id = tf.expand_dims(tf.argmax(logits, axis=-1), axis=1)
            decoded_ids = tf.concat([decoded_ids, next_id], axis=1)
            return i + 1, next_id, decoded_ids, cache

        key_channels = hparams.attention_key_channels or hparams.hidden_size
        value_channels = hparams.attention_value_channels or hparams.hidden_size
        num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers

        cache = {
            "layer_%d" % layer: {
                "k": tf.zeros([batch_size, 0, key_channels]),
                "v": tf.zeros([batch_size, 0, value_channels]),
            }
            for layer in range(num_layers)
        }
        decoded_ids = tf.zeros([batch_size, 0], dtype=tf.int64)
        next_id = tf.zeros([batch_size, 1], dtype=tf.int64)
        _, _, decoded_ids, _ = tf.while_loop(
            # TODO(llion): Early stopping.
            lambda i, *_: tf.less(i, decode_length),
            inner_loop,
            [tf.constant(0), next_id, decoded_ids, cache],
            shape_invariants=[
                tf.TensorShape([]),
                tf.TensorShape([None, None]),
                tf.TensorShape([None, None]),
                {
                    "layer_%d" % layer: {
                        "k": tf.TensorShape([None, None, key_channels]),
                        "v": tf.TensorShape([None, None, value_channels]),
                    }
                    for layer in range(num_layers)
                }
            ])

        return decoded_ids, None, None
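
The greedy loop above threads a per-layer key/value cache through tf.while_loop and only ever queries symbols_to_logits_fn for the newest position. The following framework-free sketch mirrors that control flow; greedy_decode and toy_symbols_to_logits_fn are illustrative names, not tensor2tensor APIs.

import numpy as np

def greedy_decode(symbols_to_logits_fn, batch_size, decode_length, eos_id=1):
    """Minimal greedy decoding loop mirroring the tf.while_loop above.

    symbols_to_logits_fn(next_id, step, cache) -> (logits [batch, vocab], cache)
    """
    cache = {}                                           # would grow with cached keys/values
    next_id = np.zeros((batch_size, 1), dtype=np.int64)
    decoded = np.zeros((batch_size, 0), dtype=np.int64)
    for i in range(decode_length):
        logits, cache = symbols_to_logits_fn(next_id, i, cache)
        next_id = logits.argmax(axis=-1)[:, None]        # greedy pick for the newest position
        decoded = np.concatenate([decoded, next_id], axis=1)
        if np.all(next_id == eos_id):                    # early stopping (the TODO above)
            break
    return decoded

def toy_symbols_to_logits_fn(next_id, i, cache):
    vocab = 5
    logits = np.zeros((next_id.shape[0], vocab))
    logits[:, (i + 2) % vocab] = 1.0                     # deterministic toy distribution
    return logits, cache

print(greedy_decode(toy_symbols_to_logits_fn, batch_size=2, decode_length=4))
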
Example #29
0
    def _fast_decode(self,
                     features,
                     decode_length,
                     beam_size=1,
                     top_beams=1,
                     alpha=1.0):
        #dp = self._data_parallelism
        hparams = self._hparams
        target_modality = self._problem_hparams.modality["targets"]

        inputs = features["inputs"]

        decode_length = (common_layers.shape_list(inputs)[1] +
                         features.get("decode_length", decode_length))

        #inputs = tf.expand_dims(inputs, axis=1)
        #if len(inputs.shape) < 5:
        #    inputs = tf.expand_dims(inputs, axis=4)

        s = common_layers.shape_list(inputs)
        batch_size = s[0]
        #inputs = tf.reshape(inputs, [s[0] * s[1], s[2], s[3], s[4]])
        # _shard_features called to ensure that the variable names match
        #inputs = self._shard_features({"inputs": inputs})["inputs"]
        input_modality = self._problem_hparams.modality["inputs"]
        context_modality = {}

        contexts = {}
        for feature_name in features:
            if 'context' in feature_name and 'raw' not in feature_name:
                contexts[feature_name] = features[feature_name]

        for context_name in contexts:
            if context_name in self._problem_hparams.modality:
                context_modality[
                    context_name] = self._problem_hparams.modality[
                        context_name]
            else:
                context_modality[context_name] = input_modality

        with tf.variable_scope(input_modality.name, reuse=tf.AUTO_REUSE):
            inputs = input_modality.bottom(inputs)
            for context_name in contexts:
                contexts[context_name] = context_modality[context_name].bottom(
                    contexts[context_name])

        with tf.variable_scope("body", reuse=tf.AUTO_REUSE):
            encoder_output, encoder_decoder_attention_bias = self.encode(
                inputs,
                contexts,
                features["target_space_id"],
                hparams,
                features=features)
        #encoder_output = encoder_output[0]
        #encoder_decoder_attention_bias = encoder_decoder_attention_bias[0]
        partial_targets = None

        if hparams.pos == "timing":
            positional_encoding = common_attention.get_timing_signal_1d(
                decode_length + 1, hparams.hidden_size)
        elif hparams.pos == "emb":
            positional_encoding = common_attention.add_positional_embedding(
                tf.zeros([1, decode_length + 1, hparams.hidden_size]),
                hparams.max_length, "targets_positional_embedding", None)
        else:
            positional_encoding = None

        def preprocess_targets(targets, i):
            """Performs preprocessing steps on the targets to prepare for the decoder.
            This includes:
              - Embedding the ids.
              - Flattening to 3D tensor.
              - Optionally adding timing signals.
            Args:
              targets: inputs ids to the decoder. [batch_size, 1]
              i: scalar, Step number of the decoding loop.
            Returns:
              Processed targets [batch_size, 1, hidden_dim]
            """
            # _shard_features called to ensure that the variable names match
            #targets = self._shard_features({"targets": targets})["targets"]
            with tf.variable_scope(target_modality.name):
                targets = target_modality.targets_bottom(targets)
            targets = common_layers.flatten4d3d(targets)

            targets = tf.cond(tf.equal(i, 0), lambda: tf.zeros_like(targets),
                              lambda: targets)

            if positional_encoding is not None:
                targets += positional_encoding[:, i:i + 1]
            return targets

        decoder_self_attention_bias = (
            common_attention.attention_bias_lower_triangle(decode_length))
        if hparams.proximity_bias:
            decoder_self_attention_bias += common_attention.attention_bias_proximal(
                decode_length)

        def symbols_to_logits_fn(ids, i, cache):
            """Go from ids to logits for next symbol."""
            ids = ids[:, -1:]
            targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
            targets = preprocess_targets(targets, i)

            bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]

            with tf.variable_scope("body"):
                body_outputs = self.decode(
                    targets,
                    cache.get("encoder_output"),
                    cache.get("encoder_decoder_attention_bias"),
                    bias,
                    hparams,
                    cache,
                    nonpadding=features_to_nonpadding(features, "targets"))

            with tf.variable_scope(target_modality.name):
                logits = target_modality.top(body_outputs, None)

            ret = tf.squeeze(logits, axis=[1, 2, 3])
            return ret, cache

        ret = fast_decode(
            encoder_output=encoder_output,
            encoder_decoder_attention_bias=encoder_decoder_attention_bias,
            symbols_to_logits_fn=symbols_to_logits_fn,
            hparams=hparams,
            decode_length=decode_length,
            vocab_size=target_modality.top_dimensionality,
            beam_size=beam_size,
            top_beams=top_beams,
            alpha=alpha,
            batch_size=batch_size,
            force_decode_length=self._decode_hparams.force_decode_length)

        return ret
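
At decoding step i, symbols_to_logits_fn only needs row i of the causal bias, which is what the slice decoder_self_attention_bias[:, :, i:i + 1, :i + 1] extracts. A minimal numpy illustration of that lower-triangular bias and its per-step slice (not the common_attention implementation):

import numpy as np

NEG_INF = -1e9

def lower_triangle_bias(length):
    """Additive bias [1, 1, length, length]: 0 on/below the diagonal, NEG_INF above,
    so position i may only attend to positions <= i."""
    mask = np.triu(np.ones((length, length), dtype=np.float32), k=1)  # 1 strictly above diag
    return (mask * NEG_INF)[None, None, :, :]

bias = lower_triangle_bias(4)
i = 2
step_bias = bias[:, :, i:i + 1, :i + 1]   # the slice used at decoding step i
print(step_bias[0, 0])                    # [[0. 0. 0.]] -> step 2 may attend to positions 0..2
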
Example #30
0
    def _fast_decode(
        self,
        features,
        decode_length,
        beam_size=1,
        top_beams=1,
        alpha=1.0,
        preprocess_targets_method=None,
    ):
        if self._num_datashards != 1:
            raise NotImplementedError(
                'Fast decoding only supports a single shard.')
        dp = self._data_parallelism
        hparams = self._hparams
        target_modality = self._problem_hparams.modality['targets']
        target_vocab_size = self._problem_hparams.vocab_size['targets']
        if target_vocab_size is not None and hasattr(hparams, 'vocab_divisor'):
            target_vocab_size += (-target_vocab_size) % hparams.vocab_divisor

        target_tag_modality = self._problem_hparams.modality[
            'targets_error_tag']
        target_tag_vocab_size = self._problem_hparams.vocab_size[
            'targets_error_tag']
        if target_tag_vocab_size is not None and hasattr(
                hparams, 'vocab_divisor'):
            target_tag_vocab_size += (
                -target_tag_vocab_size) % hparams.vocab_divisor

        if 'targets_segmentation' in features:
            raise NotImplementedError(
                'Decoding not supported on packed datasets '
                ' If you want to decode from a dataset, use the non-packed version'
                ' of the dataset when decoding.')
        if self.has_input:
            inputs_shape = common_layers.shape_list(features['inputs'])
            if (target_modality == modalities.ModalityType.CLASS_LABEL
                    or self._problem_hparams.get('regression_targets')):
                decode_length = 1
            else:
                decode_length = inputs_shape[1] + features.get(
                    'decode_length', decode_length)
            batch_size = inputs_shape[0]
            inputs = self._prepare_inputs_for_decode(features)
            with tf.variable_scope('body'):
                encoder_output, encoder_decoder_attention_bias = dp(
                    self.encode,
                    inputs,
                    features['target_space_id'],
                    hparams,
                    features=features,
                )
            encoder_output = encoder_output[0]
            encoder_decoder_attention_bias = encoder_decoder_attention_bias[0]
            partial_targets = features.get('partial_targets')
        else:
            encoder_output = None
            encoder_decoder_attention_bias = None
            partial_targets = features.get('inputs')
            if partial_targets is None:
                partial_targets = features['targets']
            assert partial_targets is not None

        if partial_targets is not None:
            partial_targets = common_layers.expand_squeeze_to_nd(
                partial_targets, 2)
            partial_targets = tf.to_int64(partial_targets)
            partial_targets_shape = common_layers.shape_list(partial_targets)
            partial_targets_length = partial_targets_shape[1]
            decode_length = partial_targets_length + features.get(
                'decode_length', decode_length)
            batch_size = partial_targets_shape[0]

        if hparams.pos == 'timing':
            positional_encoding = common_attention.get_timing_signal_1d(
                decode_length + 1, hparams.hidden_size)
        elif hparams.pos == 'timing_from_features':
            positional_encoding = common_attention.add_timing_signals_from_features(
                tf.zeros([1, decode_length, hparams.hidden_size]),
                features,
                hparams.position_features,
            )
        elif hparams.pos == 'emb':
            positional_encoding = common_attention.add_positional_embedding(
                tf.zeros([1, decode_length, hparams.hidden_size]),
                hparams.max_length,
                'body/targets_positional_embedding',
                None,
            )
        else:
            positional_encoding = None

        def preprocess_targets(targets, i):
            targets = self._shard_features({'targets': targets})['targets']
            modality_name = hparams.name.get(
                'targets',
                modalities.get_name(target_modality))(hparams,
                                                      target_vocab_size)
            with tf.variable_scope(modality_name + '/targets'):
                bottom = hparams.bottom.get(
                    'targets', modalities.get_targets_bottom(target_modality))
                targets = dp(bottom, targets, hparams, target_vocab_size)[0]
            targets = common_layers.flatten4d3d(targets)

            if not self.get_decode_start_id():
                targets = tf.cond(
                    tf.equal(i, 0),
                    lambda: tf.zeros_like(targets),
                    lambda: targets,
                )

            if positional_encoding is not None:
                targets += positional_encoding[:, i:i + 1]
            return targets

        def preprocess_targets_tag_method(targets, i):
            targets = self._shard_features({'targets_error_tag':
                                            targets})['targets_error_tag']
            modality_name = hparams.name.get(
                'targets_error_tag', modalities.get_name(target_tag_modality))(
                    hparams, target_tag_vocab_size)
            with tf.variable_scope(modality_name + '/targets_error_tag'):
                bottom = hparams.bottom.get(
                    'targets_error_tag',
                    modalities.get_targets_bottom(target_tag_modality),
                )
                targets = dp(bottom, targets, hparams,
                             target_tag_vocab_size)[0]
            targets = common_layers.flatten4d3d(targets)
            if not self.get_decode_start_id():
                targets = tf.cond(
                    tf.equal(i, 0),
                    lambda: tf.zeros_like(targets),
                    lambda: targets,
                )

            if positional_encoding is not None:
                targets += positional_encoding[:, i:i + 1]
            return targets

        decoder_self_attention_bias = common_attention.attention_bias_lower_triangle(
            decode_length)
        if hparams.proximity_bias:
            decoder_self_attention_bias += common_attention.attention_bias_proximal(
                decode_length)
        att_cache = {'attention_history': {}}
        num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers
        if encoder_output is not None:
            att_batch_size, enc_seq_length = common_layers.shape_list(
                encoder_output)[0:2]
            for layer in range(num_layers):
                att_cache['attention_history']['layer_%d' % layer] = tf.zeros(
                    [att_batch_size, hparams.num_heads, 0, enc_seq_length])

        def update_decoder_attention_history(cache):
            for k in [
                    x for x in self.attention_weights
                    if 'decoder' in x and 'self' not in x and 'logits' not in x
            ]:
                idx = k.find('layer_')
                if idx < 0:
                    continue
                # Get layer number from the string name.
                layer_nbr = k[idx + 6:]
                idx = 0
                while (idx + 1 < len(layer_nbr)
                       and layer_nbr[:idx + 1].isdigit()):
                    idx += 1
                layer_nbr = 'layer_%d' % int(layer_nbr[:idx])
                if layer_nbr in cache['attention_history']:
                    cache['attention_history'][layer_nbr] = tf.concat(
                        [
                            cache['attention_history'][layer_nbr],
                            self.attention_weights[k],
                        ],
                        axis=2,
                    )

        if not preprocess_targets_method:
            preprocess_targets_method = preprocess_targets

        def symbols_to_logits_fn(ids, ids_tag, i, cache):
            """Go from ids to logits for next symbol."""
            ids = ids[:, -1:]
            targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
            targets = preprocess_targets_method(targets, i)

            ids_tag = ids_tag[:, -1:]
            targets_tag = tf.expand_dims(tf.expand_dims(ids_tag, axis=2),
                                         axis=3)
            targets_tag = preprocess_targets_tag_method(targets_tag, i)

            bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]

            with tf.variable_scope('body'):
                with tf.variable_scope('edit_ops_layer'):
                    with tf.variable_scope('ffn'):
                        x = targets
                        preproc = lambda z: common_layers.layer_preprocess(
                            z, hparams, layer_collection=None)
                        layer_inputs = [
                            tf.concat(preproc(x), axis=0),
                            tf.concat(preproc(targets_tag), axis=0),
                        ]
                        y = transformer_layers.transformer_ffn_layer(
                            tf.concat(layer_inputs, axis=2),
                            hparams,
                            conv_padding='LEFT',
                            nonpadding_mask=features_to_nonpadding(
                                features, 'targets'),
                            losses=None,
                            cache=cache,
                            decode_loop_step=None,
                            layer_collection=None,
                        )
                        targets = common_layers.layer_postprocess(
                            x, y, hparams)

                if hparams.middle_prediction:
                    num_decoder_layers = (hparams.num_decoder_layers
                                          or hparams.num_hidden_layers)
                    hparams.num_decoder_layers = int(
                        num_decoder_layers /
                        hparams.middle_prediction_layer_factor)

                body_outputs = dp(
                    self.decode,
                    targets,
                    cache.get('encoder_output'),
                    cache.get('encoder_decoder_attention_bias'),
                    bias,
                    hparams,
                    cache,
                    nonpadding=features_to_nonpadding(features, 'targets'),
                )[0]

                body_outputs, logits_tag = dp(
                    self._prediction_cascade_predict,
                    hparams,
                    features_to_nonpadding(features, 'targets'),
                    cache.get('encoder_decoder_attention_bias'),
                    cache.get('encoder_output'),
                    body_outputs,
                )
                logits_tag = logits_tag[0]['targets_error_tag']
                if hparams.middle_prediction:
                    with tf.variable_scope('after_prediction'):
                        body_outputs = dp(
                            self.decode,
                            targets + body_outputs[0],
                            cache.get('encoder_output'),
                            cache.get('encoder_decoder_attention_bias'),
                            bias,
                            hparams,
                            cache,
                            nonpadding=features_to_nonpadding(
                                features, 'targets'),
                        )

            update_decoder_attention_history(cache)

            modality_name = hparams.name.get(
                'targets',
                modalities.get_name(target_modality))(hparams,
                                                      target_vocab_size)
            with tf.variable_scope('targets/' + modality_name):
                top = hparams.top.get('targets',
                                      modalities.get_top(target_modality))
                logits = dp(top, body_outputs, None, hparams,
                            target_vocab_size)[0]

            ret = tf.squeeze(logits, axis=[1, 2])
            if partial_targets is not None:
                vocab_size = tf.shape(ret)[1]

                def forced_logits():
                    return tf.one_hot(
                        tf.tile(partial_targets[:, i], [beam_size]),
                        vocab_size,
                        0.0,
                        -1e9,
                    )

                ret = tf.cond(
                    tf.less(i, partial_targets_length),
                    forced_logits,
                    lambda: ret,
                )
            logits_tag = tf.squeeze(logits_tag, axis=[1])
            return ret, logits_tag, cache

        sos_id = self.get_decode_start_id() or 0
        eos_id = self.get_decode_end_id() or beam_search.EOS_ID
        temperature = features.get('sampling_temp',
                                   getattr(hparams, 'sampling_temp', 0.0))
        top_k = features.get('sampling_keep_top_k',
                             getattr(hparams, 'sampling_keep_top_k', -1))
        ret = fast_decode(
            encoder_output=encoder_output,
            encoder_decoder_attention_bias=encoder_decoder_attention_bias,
            symbols_to_logits_fn=symbols_to_logits_fn,
            hparams=hparams,
            decode_length=decode_length,
            vocab_size=target_vocab_size,
            init_cache_fn=_init_transformer_cache,
            beam_size=beam_size,
            top_beams=top_beams,
            alpha=alpha,
            batch_size=batch_size,
            force_decode_length=self._decode_hparams.force_decode_length,
            sos_id=sos_id,
            eos_id=eos_id,
            sampling_temperature=temperature,
            top_k=top_k,
            cache=att_cache,
        )
        if partial_targets is not None:
            if beam_size <= 1 or top_beams <= 1:
                ret['outputs'] = ret['outputs'][:, partial_targets_length:]
            else:
                ret['outputs'] = ret['outputs'][:, :, partial_targets_length:]
        return ret
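
When partial targets are supplied, the logits for the first partial_targets_length steps are overridden with a one-hot distribution: 0 at the forced id and -1e9 everywhere else, so greedy argmax and beam search both reproduce the prefix. A toy numpy illustration of that trick, with hypothetical names:

import numpy as np

def forced_logits(forced_id, vocab_size, on_value=0.0, off_value=-1e9):
    """One-hot 'logits' that make any argmax/beam step emit forced_id."""
    logits = np.full((vocab_size,), off_value, dtype=np.float32)
    logits[forced_id] = on_value
    return logits

model_logits = np.array([0.1, 2.3, -0.4, 0.9], dtype=np.float32)
step, prefix_len, forced = 1, 3, 2        # step 1 is still inside the forced prefix
logits = forced_logits(forced, 4) if step < prefix_len else model_logits
print(logits.argmax())                    # 2, the forced id, regardless of the model
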
Example #31
0
    def _fast_decode(self,
                     features,
                     decode_length,
                     beam_size=1,
                     top_beams=1,
                     alpha=1.0):
        """
		Fast decoding.
		Implements both greedy and beam search decoding, uses beam search iff
		beam_size > 1, otherwise beam search related arguments are ignored.
		Args:
			features: a map of string to model  features.
			decode_length: an integer.  How many additional timesteps to decode.
			beam_size: number of beams.
			top_beams: an integer. How many of the beams to return.
			alpha: Float that controls the length penalty. larger the alpha, stronger
			the preference for slonger translations.
		Returns:
			 samples: an integer `Tensor`. Top samples from the beam search
		Raises:
			NotImplementedError: If there are multiple data shards.
		"""
        if self._num_datashards != 1:
            raise NotImplementedError(
                "Fast decoding only supports a single shard.")
        dp = self._data_parallelism
        hparams = self._hparams

        inputs = features["inputs"]
        batch_size = common_layers.shape_list(inputs)[0]
        target_modality = self._problem_hparams.target_modality
        if target_modality.is_class_modality:
            decode_length = 1
        else:
            decode_length = common_layers.shape_list(inputs)[1] + decode_length

        # TODO(llion): Clean up this reshaping logic.
        inputs = tf.expand_dims(inputs, axis=1)
        if len(inputs.shape) < 5:
            inputs = tf.expand_dims(inputs, axis=4)
        s = common_layers.shape_list(inputs)
        inputs = tf.reshape(inputs, [s[0] * s[1], s[2], s[3], s[4]])
        # _shard_features called to ensure that the variable names match
        inputs = self._shard_features({"inputs": inputs})["inputs"]
        input_modality = self._problem_hparams.input_modality["inputs"]
        with tf.variable_scope(input_modality.name):
            inputs = input_modality.bottom_sharded(inputs, dp)
        with tf.variable_scope("body"):
            encoder_output, encoder_decoder_attention_bias = dp(
                self.encode,
                inputs,
                features["target_space_id"],
                hparams,
                features=features)
        encoder_output = encoder_output[0]
        encoder_decoder_attention_bias = encoder_decoder_attention_bias[0]

        if hparams.pos == "timing":
            timing_signal = common_attention.get_timing_signal_1d(
                decode_length + 1, hparams.hidden_size)

        def preprocess_targets(targets, i):
            """Performs preprocessing steps on the targets to prepare for the decoder.
			This includes:
			- Embedding the ids.
			- Flattening to 3D tensor.
			- Optionally adding timing signals.
			Args:
			targets: inputs ids to the decoder. [batch_size, 1]
			i: scalar, Step number of the decoding loop.
			Returns:
			Processed targets [batch_size, 1, hidden_dim]
			"""
            # _shard_features called to ensure that the variable names match
            targets = self._shard_features({"targets": targets})["targets"]
            with tf.variable_scope(target_modality.name):
                targets = target_modality.targets_bottom_sharded(targets,
                                                                 dp)[0]
            targets = common_layers.flatten4d3d(targets)

            # TODO(llion): Explain! Is this even needed?
            targets = tf.cond(tf.equal(i, 0), lambda: tf.zeros_like(targets),
                              lambda: targets)

            if hparams.pos == "timing":
                targets += timing_signal[:, i:i + 1]
            return targets

        decoder_self_attention_bias = (
            common_attention.attention_bias_lower_triangle(decode_length))
        if hparams.proximity_bias:
            decoder_self_attention_bias += common_attention.attention_bias_proximal(
                decode_length)

        def symbols_to_logits_fn(ids, i, cache):
            """Go from ids to logits for next symbol."""
            ids = ids[:, -1:]
            targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
            targets = preprocess_targets(targets, i)

            bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]

            with tf.variable_scope("body"):
                body_outputs = dp(
                    self.decode,
                    targets,
                    cache["encoder_output"],
                    cache["encoder_decoder_attention_bias"],
                    bias,
                    hparams,
                    cache,
                    nonpadding=transformer._features_to_nonpadding(
                        features, "targets"))

            with tf.variable_scope(target_modality.name):
                logits = target_modality.top_sharded(body_outputs, None, dp)[0]

            return tf.squeeze(logits, axis=[1, 2, 3]), cache

        key_channels = hparams.attention_key_channels or hparams.hidden_size
        value_channels = hparams.attention_value_channels or hparams.hidden_size
        num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers

        cache = {
            "layer_%d" % layer: {
                "k": tf.zeros([batch_size, 0, key_channels]),
                "v": tf.zeros([batch_size, 0, value_channels]),
            }
            for layer in range(num_layers)
        }

        # Set 2nd dim to None since it's not invariant in the tf.while_loop
        # Note: Tensor.set_shape() does not work here since it merges shape info.
        # TODO(llion); Find a more robust solution.
        # pylint: disable=protected-access
        if not context.in_eager_mode():
            for layer in cache:
                cache[layer]["k"]._shape = tf.TensorShape(
                    [None, None, key_channels])
                cache[layer]["v"]._shape = tf.TensorShape(
                    [None, None, value_channels])
        # pylint: enable=protected-access
        cache["encoder_output"] = encoder_output
        cache[
            "encoder_decoder_attention_bias"] = encoder_decoder_attention_bias

        if beam_size > 1:  # Beam Search
            target_modality = (
                self._hparams.problems[self._problem_idx].target_modality)
            vocab_size = target_modality.top_dimensionality
            initial_ids = tf.zeros([batch_size], dtype=tf.int32)
            decoded_ids, scores = beam_search.beam_search(
                symbols_to_logits_fn,
                initial_ids,
                beam_size,
                decode_length,
                vocab_size,
                alpha,
                states=cache,
                stop_early=(top_beams == 1))

            decoded_ids = decoded_ids[:, :, 1:]

            # do roulette wheel selection or inverse roulette wheel selection
            if self._hparams.roulette == "Normal" or self._hparams.roulette == "Inverse":
                if self._hparams.roulette == "Normal":
                    probabilities = tf.pow(tf.constant(2.0), scores)
                    start = 0
                else:
                    probabilities = tf.subtract(
                        tf.constant(1.0), tf.pow(tf.constant(2.0), scores))
                    start = beam_size - self._hparams.roulette_beam_size

                summ = tf.reduce_sum(probabilities)
                ex_probs = tf.divide(probabilities, summ)
                #ex_probs=tf.nn.softmax(probabilities)

                # sample a number between 0 and 1
                wheel = tf.random_uniform([1])
                upper_bound = tf.constant(0.0)

                # change this as well if using inverse
                for i in range(start, self._hparams.roulette_beam_size):
                    upper_bound = tf.add(ex_probs[:, i], upper_bound)
                    truthValue = tf.squeeze(
                        tf.logical_and(wheel >= upper_bound - ex_probs[:, i],
                                       wheel <= upper_bound))
                    decoded_ids, scores, i = tf.cond(
                        truthValue, lambda:
                        (decoded_ids[:, i, :], scores[:, i], beam_size),
                        lambda: (decoded_ids, scores, i))

        else:  # Greedy

            def inner_loop(i, next_id, decoded_ids, cache):
                logits, cache = symbols_to_logits_fn(next_id, i, cache)
                temperature = (0.0 if hparams.sampling_method == "argmax" else
                               hparams.sampling_temp)
                next_id = tf.expand_dims(common_layers.sample_with_temperature(
                    logits, temperature),
                                         axis=1)
                decoded_ids = tf.concat([decoded_ids, next_id], axis=1)
                return i + 1, next_id, decoded_ids, cache

            decoded_ids = tf.zeros([batch_size, 0], dtype=tf.int64)
            scores = None
            next_id = tf.zeros([batch_size, 1], dtype=tf.int64)
            _, _, decoded_ids, _ = tf.while_loop(
                # TODO(llion): Early stopping.
                lambda i, *_: tf.less(i, decode_length),
                inner_loop,
                [tf.constant(0), next_id, decoded_ids, cache],
                shape_invariants=[
                    tf.TensorShape([]),
                    tf.TensorShape([None, None]),
                    tf.TensorShape([None, None]),
                    nest.map_structure(lambda t: tf.TensorShape(t.shape),
                                       cache),
                ])

        return decoded_ids, scores
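
The beam-search branch above can replace "take the top beam" with roulette-wheel selection over the beam scores, converting them to selection weights via 2**score (and 1 - 2**score in the "Inverse" mode, which favours low-scoring beams). A standalone numpy sketch of that selection, with illustrative names rather than tensor2tensor APIs:

import numpy as np

def roulette_select(decoded_ids, scores, inverse=False, rng=None):
    """Pick one beam per example by roulette-wheel sampling.

    decoded_ids: [batch, beam, length] int array; scores: [batch, beam] beam
    scores, turned into selection weights via 2**score as in the code above.
    """
    rng = rng or np.random.default_rng(0)
    weights = np.power(2.0, scores)
    if inverse:
        weights = 1.0 - weights                       # favour low-scoring beams
    probs = weights / weights.sum(axis=1, keepdims=True)
    picked = [rng.choice(p.shape[0], p=p) for p in probs]
    return np.stack([decoded_ids[b, k] for b, k in enumerate(picked)])

ids = np.arange(2 * 3 * 4).reshape(2, 3, 4)           # 2 examples, 3 beams, length 4
scores = np.log2(np.array([[0.6, 0.3, 0.1], [0.2, 0.5, 0.3]]))
print(roulette_select(ids, scores))
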
Example #32
0
    def _fast_decode_tpu(self, features, decode_length, beam_size=1):
        """Fast decoding.

        Implements only greedy decoding on TPU.

        Args:
          features: A map of string to model features.
          decode_length: An integer, how many additional timesteps to decode.
          beam_size: An integer, number of beams.

        Returns:
          A dict of decoding results {
              "outputs": integer `Tensor` of decoded ids of shape
                  [batch_size, <= decode_length]
              "scores": decoding log probs from the beam search,
                  None if using greedy decoding (beam_size=1)
          }.

        Raises:
          NotImplementedError: If there are multiple data shards or beam_size > 1.
        """
        if self._num_datashards != 1:
            raise NotImplementedError(
                "Fast decoding only supports a single shard.")
        if "targets_segmentation" in features:
            raise NotImplementedError(
                "Decoding not supported on packed datasets "
                " If you want to decode from a dataset, use the non-packed version"
                " of the dataset when decoding.")
        dp = self._data_parallelism
        hparams = self._hparams
        target_modality = self._problem_hparams.target_modality

        if self.has_input:
            inputs = features["inputs"]
            if target_modality.is_class_modality:
                decode_length = 1
            else:
                decode_length = (common_layers.shape_list(inputs)[1] +
                                 features.get("decode_length", decode_length))

            # TODO(llion): Clean up this reshaping logic.
            inputs = tf.expand_dims(inputs, axis=1)
            if len(inputs.shape) < 5:
                inputs = tf.expand_dims(inputs, axis=4)
            s = common_layers.shape_list(inputs)
            batch_size = s[0]
            inputs = tf.reshape(inputs, [s[0] * s[1], s[2], s[3], s[4]])
            # _shard_features called to ensure that the variable names match
            inputs = self._shard_features({"inputs": inputs})["inputs"]
            input_modality = self._problem_hparams.input_modality["inputs"]
            with tf.variable_scope(input_modality.name):
                inputs = input_modality.bottom_sharded(inputs, dp)
            with tf.variable_scope("body"):
                encoder_output, encoder_decoder_attention_bias = dp(
                    self.encode,
                    inputs,
                    features["target_space_id"],
                    hparams,
                    features=features)
            encoder_output = encoder_output[0]
            encoder_decoder_attention_bias = encoder_decoder_attention_bias[0]
            partial_targets = None
        else:
            # The problem has no inputs.
            encoder_output = None
            encoder_decoder_attention_bias = None

            # Prepare partial targets.
            # In either features["inputs"] or features["targets"].
            # We force the outputs to begin with these sequences.
            partial_targets = features.get("inputs")
            if partial_targets is None:
                partial_targets = features["targets"]
            assert partial_targets is not None
            partial_targets = common_layers.expand_squeeze_to_nd(
                partial_targets, 2)
            partial_targets = tf.to_int64(partial_targets)
            partial_targets_shape = common_layers.shape_list(partial_targets)
            partial_targets_length = partial_targets_shape[1]
            decode_length = (partial_targets_length +
                             features.get("decode_length", decode_length))
            batch_size = partial_targets_shape[0]

        if hparams.pos == "timing":
            positional_encoding = common_attention.get_timing_signal_1d(
                decode_length + 1, hparams.hidden_size)
        elif hparams.pos == "emb":
            positional_encoding = common_attention.add_positional_embedding(
                tf.zeros([1, decode_length + 1, hparams.hidden_size]),
                hparams.max_length, "body/targets_positional_embedding", None)
        else:
            positional_encoding = None

        def preprocess_targets(targets, i):
            """Performs preprocessing steps on the targets to prepare for the decoder.

            This includes:
              - Embedding the ids.
              - Flattening to 3D tensor.
              - Optionally adding timing signals.

            Args:
              targets: A tensor, inputs ids to the decoder. [batch_size, 1].
              i: An integer, Step number of the decoding loop.

            Returns:
              A tensor, processed targets [batch_size, 1, hidden_dim].
            """
            # _shard_features called to ensure that the variable names match
            targets = self._shard_features({"targets": targets})["targets"]
            with tf.variable_scope(target_modality.name):
                targets = target_modality.targets_bottom_sharded(targets,
                                                                 dp)[0]
            targets = common_layers.flatten4d3d(targets)

            # TODO(llion): Explain! Is this even needed?
            targets = tf.cond(tf.equal(i, 0), lambda: tf.zeros_like(targets),
                              lambda: targets)

            if positional_encoding is not None:
                positional_encoding_shape = positional_encoding.shape.as_list()
                targets += tf.slice(positional_encoding, [0, i, 0], [
                    positional_encoding_shape[0], 1,
                    positional_encoding_shape[2]
                ])
            return targets

        decoder_self_attention_bias = (
            common_attention.attention_bias_lower_triangle(decode_length))
        if hparams.proximity_bias:
            decoder_self_attention_bias += common_attention.attention_bias_proximal(
                decode_length)

        def symbols_to_logits_tpu_fn(ids, i, cache):
            """Go from ids to logits for next symbol on TPU.

            Args:
              ids: A tensor, symbol IDs.
              i: An integer, step number of the decoding loop. Only used for inference
                  on TPU.
              cache: A dict, containing tensors which are the results of previous
                  attentions, used for fast decoding.

            Returns:
              ret: A tensor, computed logits.
              cache: A dict, containing tensors which are the results of previous
                  attentions, used for fast decoding.
            """
            ids = ids[:, -1:]
            targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
            targets = preprocess_targets(targets, i)

            bias_shape = decoder_self_attention_bias.shape.as_list()
            bias = tf.slice(decoder_self_attention_bias, [0, 0, i, 0],
                            [bias_shape[0], bias_shape[1], 1, bias_shape[3]])

            with tf.variable_scope("body"):
                body_outputs = dp(self.decode,
                                  targets,
                                  cache.get("encoder_output"),
                                  cache.get("encoder_decoder_attention_bias"),
                                  bias,
                                  hparams,
                                  cache,
                                  i,
                                  nonpadding=features_to_nonpadding(
                                      features, "targets"))

            with tf.variable_scope(target_modality.name):
                logits = target_modality.top_sharded(body_outputs, None, dp)[0]

            ret = tf.squeeze(logits, axis=[1, 2, 3])
            if partial_targets is not None:
                # If the position is within the given partial targets, we alter the
                # logits to always return those values.
                # A faster approach would be to process the partial targets in one
                # iteration in order to fill the corresponding parts of the cache.
                # This would require broader changes, though.
                vocab_size = tf.shape(ret)[1]

                def forced_logits():
                    return tf.one_hot(
                        tf.tile(
                            tf.slice(partial_targets, [0, i],
                                     [partial_targets.shape.as_list()[0], 1]),
                            [beam_size]), vocab_size, 0.0, -1e9)

                ret = tf.cond(tf.less(i, partial_targets_length),
                              forced_logits, lambda: ret)
            return ret, cache

        ret = fast_decode_tpu(
            encoder_output=encoder_output,
            encoder_decoder_attention_bias=encoder_decoder_attention_bias,
            symbols_to_logits_fn=symbols_to_logits_tpu_fn,
            hparams=hparams,
            decode_length=decode_length,
            beam_size=beam_size,
            batch_size=batch_size,
            force_decode_length=self._decode_hparams.force_decode_length)
        if partial_targets is not None:
            ret["outputs"] = ret["outputs"][:, partial_targets_length:]
        return ret
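
The TPU path slices a precomputed positional encoding with tf.slice and static shapes rather than Python-style indexing. For reference, here is a numpy sketch of the standard sinusoidal timing signal that get_timing_signal_1d produces (a rendering of the well-known formula, not the tensor2tensor source):

import numpy as np

def timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
    """Sinusoidal positional encoding of shape [1, length, channels]."""
    position = np.arange(length, dtype=np.float32)
    num_timescales = channels // 2
    log_increment = np.log(max_timescale / min_timescale) / max(num_timescales - 1, 1)
    inv_timescales = min_timescale * np.exp(-np.arange(num_timescales) * log_increment)
    scaled = position[:, None] * inv_timescales[None, :]           # [length, channels // 2]
    signal = np.concatenate([np.sin(scaled), np.cos(scaled)], axis=1)
    signal = np.pad(signal, [(0, 0), (0, channels % 2)])           # pad if channels is odd
    return signal[None, :, :]

signal = timing_signal_1d(length=8, channels=6)
i = 3
print(signal[:, i:i + 1, :].shape)   # (1, 1, 6): the slice added to step i's target embedding
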