Example #1
  def testNotGreedyBeamTwo(self):
    batch_size = 1
    beam_size = 2
    vocab_size = 3
    decode_length = 3

    initial_ids = tf.constant([0] * batch_size)  # GO
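    # Transition probabilities indexed as [position, beam, vocab].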
    probabilities = tf.constant([[[0.1, 0.1, 0.8], [0.1, 0.1, 0.8]],
                                 [[0.4, 0.5, 0.1], [0.2, 0.4, 0.4]],
                                 [[0.05, 0.9, 0.05], [0.4, 0.4, 0.2]]])

    def symbols_to_logits(ids):
      pos = tf.shape(ids)[1]
      logits = tf.to_float(tf.log(probabilities[pos - 1, :]))
      return logits

    final_ids, final_probs = beam_search.beam_search(
        symbols_to_logits,
        initial_ids,
        beam_size,
        decode_length,
        vocab_size,
        0.0,
        eos_id=1)

    with self.test_session():
      ids = final_ids.eval()
      probs = final_probs.eval()
    self.assertAllEqual([[[0, 2, 1, 0], [0, 2, 0, 1]]], ids)
    self.assertAllClose([[0.8 * 0.5, 0.8 * 0.4 * 0.9]], np.exp(probs))
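For reference, the contract these examples share: `symbols_to_logits` takes the ids decoded so far, shaped [batch_size * beam_size, length], and returns logits shaped [batch_size * beam_size, vocab_size]; the search returns decoded ids of shape [batch_size, beam_size, <= decode_length + 1] (including the initial GO id) along with per-beam log-probability scores. A minimal sketch, assuming the tensor2tensor-style API used throughout this page (note the return arity differs across library versions: some examples below unpack two values, others three):

import tensorflow as tf
from tensor2tensor.utils import beam_search

batch_size, beam_size, vocab_size, decode_length = 1, 2, 3, 3

def uniform_symbols_to_logits(ids):
    # ids: [batch_size * beam_size, decoded_length_so_far]
    # Returns logits of shape [batch_size * beam_size, vocab_size].
    return tf.zeros([tf.shape(ids)[0], vocab_size])

final_ids, final_scores = beam_search.beam_search(
    uniform_symbols_to_logits,
    tf.zeros([batch_size], dtype=tf.int32),  # GO
    beam_size,
    decode_length,
    vocab_size,
    0.0,  # alpha: no length penalty
    eos_id=1)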
Example #2
    def testGreedyWithCornerCase(self):
        batch_size = 1
        beam_size = 1
        vocab_size = 3
        decode_length = 2

        initial_ids = tf.constant([0] * batch_size)  # GO
        probabilities = tf.constant([[0.2, 0.1, 0.7], [0.4, 0.1, 0.5]])

        def symbols_to_logits(ids):
            pos = tf.shape(ids)[1]
            logits = tf.to_float(tf.log(probabilities[pos - 1, :]))
            return logits

        final_ids, final_probs, _ = beam_search.beam_search(symbols_to_logits,
                                                            initial_ids,
                                                            beam_size,
                                                            decode_length,
                                                            vocab_size,
                                                            0.0,
                                                            eos_id=1)

        with self.test_session():
            ids = final_ids.eval()
            probs = final_probs.eval()
        self.assertAllEqual([[[0, 2, 2]]], ids)
        self.assertAllClose([[0.7 * 0.5]], np.exp(probs))
Example #3
def ae_latent_sample_beam(latents_dense_in, inputs, ed, embed, hparams):
    """Sample from the latent space in the autoencoder."""
    def symbols_to_logits_fn(ids):
        """Go from ids to logits."""
        ids = tf.expand_dims(ids, axis=2)  # Ids start with added all-zeros.
        latents_discrete = tf.pad(ids[:, 1:], [[0, 0], [0, 1], [0, 0]])

        with tf.variable_scope(tf.get_variable_scope(), reuse=False):
            latents_dense = embed(
                tf.one_hot(latents_discrete, depth=2**hparams.bottleneck_bits))
            latents_pred = decode_transformer(inputs, ed, latents_dense,
                                              hparams, "extra")
            logits = tf.layers.dense(latents_pred,
                                     2**hparams.bottleneck_bits,
                                     name="extra_logits")
            current_output_position = common_layers.shape_list(ids)[1] - 1
            logits = logits[:, current_output_position, :, :]
        return tf.squeeze(logits, axis=[1])

    initial_ids = tf.zeros([tf.shape(latents_dense_in)[0]], dtype=tf.int32)
    length = tf.shape(latents_dense_in)[1]
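    # With eos_id=-1 (matching no real id) and stop_early=False, every beam
    # decodes for the full `length` steps.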
    ids, _ = beam_search.beam_search(symbols_to_logits_fn,
                                     initial_ids,
                                     beam_size=1,
                                     decode_length=length,
                                     vocab_size=2**hparams.bottleneck_bits,
                                     alpha=0.0,
                                     eos_id=-1,
                                     stop_early=False)

    res = tf.expand_dims(ids[:, 0, :], axis=2)  # Pick first beam.
    return res[:, 1:]  # Remove the added all-zeros from ids.
Example #4
    def testStatesAfterLoop(self):
        batch_size = 1
        beam_size = 1
        vocab_size = 2
        decode_length = 3

        initial_ids = tf.constant([0] * batch_size)  # GO
        probabilities = tf.constant([[[0.7, 0.3]], [[0.4, 0.6]], [[0.5, 0.5]]])

        def symbols_to_logits(ids, _, states):
            pos = tf.shape(ids)[1] - 1
            logits = tf.to_float(tf.log(probabilities[pos, :]))
            states["state"] += 1
            return logits, states

        states = {
            "state": tf.zeros((batch_size, 1)),
        }
        states["state"] = tf.placeholder_with_default(states["state"],
                                                      shape=(None, 1))

        _, _, final_states = beam_search.beam_search(symbols_to_logits,
                                                     initial_ids,
                                                     beam_size,
                                                     decode_length,
                                                     vocab_size,
                                                     0.0,
                                                     eos_id=1,
                                                     states=states)

        with self.test_session() as sess:
            final_states = sess.run(final_states)
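        # symbols_to_logits ran twice (EOS was chosen at step two), so state == 2.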
        self.assertAllEqual([[[2]]], final_states["state"])
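When a `states` dictionary is passed, as above, the callback signature changes to `(ids, step, states)` and must return `(logits, updated_states)`; the search carries each state tensor through the decoding loop and reorders it along the beam dimension at every step. A rough sketch of that contract, under the same assumptions as the earlier sketch (again, the return arity varies by version):

import tensorflow as tf
from tensor2tensor.utils import beam_search

batch_size, beam_size, vocab_size, decode_length = 1, 2, 4, 5

def counting_symbols_to_logits(ids, step, states):
    # Inside the loop each state's leading dim is batch_size * beam_size.
    logits = tf.zeros([tf.shape(ids)[0], vocab_size])
    states["step_count"] += 1.0
    return logits, states

states = {"step_count": tf.zeros((batch_size, 1))}
# placeholder_with_default relaxes the static batch dimension so the
# while_loop shape invariants hold while beams are expanded and flattened.
states["step_count"] = tf.placeholder_with_default(
    states["step_count"], shape=(None, 1))

final_ids, final_scores, final_states = beam_search.beam_search(
    counting_symbols_to_logits,
    tf.zeros([batch_size], dtype=tf.int32),  # GO
    beam_size,
    decode_length,
    vocab_size,
    0.0,
    eos_id=1,
    states=states)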
Example #5
  def testNotGreedyBeamTwoWithoutStopEarly(self):
    batch_size = 1
    beam_size = 2
    vocab_size = 3
    decode_length = 3

    initial_ids = tf.constant([0] * batch_size)  # GO
    probabilities = tf.constant([[[0.1, 0.1, 0.8], [0.1, 0.1, 0.8]],
                                 [[0.4, 0.5, 0.1], [0.2, 0.4, 0.4]],
                                 [[0.05, 0.9, 0.05], [0.4, 0.4, 0.2]]])

    def symbols_to_logits(ids):
      pos = tf.shape(ids)[1]
      logits = tf.to_float(tf.log(probabilities[pos - 1, :]))
      return logits

    final_ids, final_probs = beam_search.beam_search(
        symbols_to_logits,
        initial_ids,
        beam_size,
        decode_length,
        vocab_size,
        0.0,
        eos_id=1,
        stop_early=False)

    with self.test_session():
      ids = final_ids.eval()
      probs = final_probs.eval()
    # given stop_early = False, the algorithm will return all the beams
    # so we can test all of them here
    self.assertAllEqual([[[0, 2, 1, 0], [0, 2, 0, 1]]], ids)
    self.assertAllClose([[0.8 * 0.5, 0.8 * 0.4 * 0.9]], np.exp(probs))
Example #6
  def testGreedyWithCornerCase(self):
    batch_size = 1
    beam_size = 1
    vocab_size = 3
    decode_length = 2

    initial_ids = tf.constant([0] * batch_size)  # GO
    probabilities = tf.constant([[0.2, 0.1, 0.7], [0.4, 0.1, 0.5]])

    def symbols_to_logits(ids):
      pos = tf.shape(ids)[1]
      logits = tf.to_float(tf.log(probabilities[pos - 1, :]))
      return logits

    final_ids, final_probs = beam_search.beam_search(
        symbols_to_logits,
        initial_ids,
        beam_size,
        decode_length,
        vocab_size,
        0.0,
        eos_id=1)

    with self.test_session():
      ids = final_ids.eval()
      probs = final_probs.eval()
    self.assertAllEqual([[[0, 2, 2]]], ids)
    self.assertAllClose([[0.7 * 0.5]], np.exp(probs))
Example #7
    def testNotGreedyBeamTwoWithoutStopEarly(self):
        batch_size = 1
        beam_size = 2
        vocab_size = 3
        decode_length = 3

        initial_ids = tf.constant([0] * batch_size)  # GO
        probabilities = tf.constant([[[0.1, 0.1, 0.8], [0.1, 0.1, 0.8]],
                                     [[0.4, 0.5, 0.1], [0.2, 0.4, 0.4]],
                                     [[0.05, 0.9, 0.05], [0.4, 0.4, 0.2]]])

        def symbols_to_logits(ids):
            pos = tf.shape(ids)[1]
            logits = tf.to_float(tf.log(probabilities[pos - 1, :]))
            return logits

        final_ids, final_probs, _ = beam_search.beam_search(symbols_to_logits,
                                                            initial_ids,
                                                            beam_size,
                                                            decode_length,
                                                            vocab_size,
                                                            0.0,
                                                            eos_id=1,
                                                            stop_early=False)

        with self.test_session():
            ids = final_ids.eval()
            probs = final_probs.eval()
        # given stop_early = False, the algorithm will return all the beams
        # so we can test all of them here
        self.assertAllEqual([[[0, 2, 1, 0], [0, 2, 0, 1]]], ids)
        self.assertAllClose([[0.8 * 0.5, 0.8 * 0.4 * 0.9]], np.exp(probs))
Example #8
def ae_latent_sample_beam(latents_dense_in, inputs, ed, embed, hparams):
  """Sample from the latent space in the autoencoder."""
  vocab_size = 2**hparams.z_size
  beam_size = 1  # TODO(lukaszkaiser): larger beam sizes seem to work bad.
  inputs = tf.tile(inputs, [beam_size, 1, 1])
  ed = tf.tile(ed, [beam_size, 1, 1, 1])

  def symbols_to_logits_fn(ids):
    """Go from ids to logits."""
    ids = tf.expand_dims(ids, axis=2)  # Ids start with added all-zeros.
    latents_discrete = tf.pad(ids[:, 1:], [[0, 0], [0, 1], [0, 0]])

    with tf.variable_scope(tf.get_variable_scope(), reuse=False):
      latents_dense = embed(latents_discrete)
      latents_pred = decode_transformer(
          inputs, ed, latents_dense, hparams, "extra")
      logits = tf.layers.dense(latents_pred, vocab_size, name="extra_logits")
      current_output_position = common_layers.shape_list(ids)[1] - 1
      logits = logits[:, current_output_position, :, :]
    return tf.squeeze(logits, axis=[1])

  initial_ids = tf.zeros([tf.shape(latents_dense_in)[0]], dtype=tf.int32)
  length = tf.shape(latents_dense_in)[1]
  ids, _ = beam_search.beam_search(
      symbols_to_logits_fn, initial_ids, beam_size, length,
      vocab_size, alpha=0.0, eos_id=-1, stop_early=False)

  res = tf.expand_dims(ids[:, 0, :], axis=2)  # Pick first beam.
  return res[:, 1:]  # Remove the added all-zeros from ids.
Example #9
  def testGreedyBatchOne(self):
    batch_size = 1
    beam_size = 1
    vocab_size = 2
    decode_length = 3

    initial_ids = tf.constant([0] * batch_size)  # GO

    # Test that beam search finds the most probable sequence.
    # These probabilities represent the following search
    #
    #               G0 (0)
    #                  / \
    #                /     \
    #              /         \
    #            /             \
    #         0(0.7)          1(0.3)
    #           / \
    #          /   \
    #         /     \
    #     0(0.4) 1(0.6)
    #        /\
    #       /  \
    #      /    \
    #    0(0.5) 1(0.5)
    # and the following decoding probabilities
    # 0000 - 0.7 * 0.4  * 0.5
    # 0001 - 0.7 * 0.4  * 0.5
    # 001 - 0.7 * 0.6 (Best)
    # 01 - 0.3
    #
    # 001 is the most likely sequence under these probabilities.
    probabilities = tf.constant([[[0.7, 0.3]], [[0.4, 0.6]], [[0.5, 0.5]]])

    def symbols_to_logits(ids):
      pos = tf.shape(ids)[1]
      logits = tf.to_float(tf.log(probabilities[pos - 1, :]))
      return logits

    final_ids, final_probs = beam_search.beam_search(
        symbols_to_logits,
        initial_ids,
        beam_size,
        decode_length,
        vocab_size,
        0.0,
        eos_id=1)

    with self.test_session():
      ids = final_ids.eval()
      probs = final_probs.eval()
    self.assertAllEqual([[[0, 0, 1]]], ids)
    self.assertAllClose([[0.7 * 0.6]], np.exp(probs))
Example #10
    def beam_search_decode(self, features, hidden_feature, mode, problem_name):
        # prepare inputs to attention
        key = 'ori_seq' if self.params.label_transfer else 'seq'
        encoder_outputs = hidden_feature[key]
        max_seq_len = self.params.max_seq_len
        embedding_table = hidden_feature['embed_table']
        token_type_ids = features['segment_ids']
        num_classes = self.params.num_classes[problem_name]
        batch_size = modeling.get_shape_list(encoder_outputs,
                                             expected_rank=3)[0]
        hidden_size = self.params.bert_config.hidden_size

        if self.params.problem_type[problem_name] == 'seq2seq_text':
            embedding_table = hidden_feature['embed_table']
        else:
            embedding_table = tf.get_variable('tag_embed_table',
                                              shape=[num_classes, hidden_size])

        symbol_to_logit_fn = self._get_symbol_to_logit_fn(
            max_seq_len=max_seq_len,
            embedding_table=embedding_table,
            token_type_ids=token_type_ids,
            decoder=self.decoder,
            num_classes=num_classes,
            encoder_output=encoder_outputs,
            input_mask=features['input_mask'],
            params=self.params)

        # create cache for fast decode
        cache = {
            str(layer): {
                "key_layer": tf.zeros([batch_size, 0, hidden_size]),
                "value_layer": tf.zeros([batch_size, 0, hidden_size]),
            }
            for layer in range(self.params.decoder_num_hidden_layers)
        }
        # cache['encoder_outputs'] = encoder_outputs
        # cache['encoder_decoder_attention_mask'] = features['input_mask']
        initial_ids = tf.zeros([batch_size], dtype=tf.int32)
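        # id 0 doubles as the GO symbol for every batch element.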

        decode_ids, _ = beam_search.beam_search(
            symbols_to_logits_fn=symbol_to_logit_fn,
            initial_ids=initial_ids,
            states=cache,
            vocab_size=self.params.num_classes[problem_name],
            beam_size=self.params.beam_size,
            alpha=self.params.beam_search_alpha,
            decode_length=self.params.decode_max_seq_len,
            eos_id=self.params.eos_id[problem_name])
        # Get the top sequence for each batch element
        top_decoded_ids = decode_ids[:, 0, 1:]
        self.prob = top_decoded_ids
        return self.prob
Example #11
def ae_latent_sample_beam(latents_dense_in, inputs, ed, embed, hparams):
    """Samples from the latent space in the autoencoder.

  Args:
    latents_dense_in: Tensor of shape [batch, length_q, ...]. Only the shape of
      its first two dimensions is used. length_q is the latent length, which is
      height * width * hparams.num_latents / (2**hparams.num_compress_steps).
    inputs: Tensor of shape [batch, length_kv, hparams.hidden_size]. Encodings
      to attend to in decoder.
    ed: Tensor which broadcasts with shape [batch, hparams.num_heads, length_q,
      length_kv]. Encoder-decoder attention bias.
    embed: Callable which embeds discrete latent hot-vectors and a hidden size
      and returns dense vectors.
    hparams: HParams.

  Returns:
    Tensor of shape [batch, length].
  """
    def symbols_to_logits_fn(ids):
        """Go from ids to logits."""
        ids = tf.expand_dims(ids, axis=2)  # Ids start with added all-zeros.
        latents_discrete = tf.pad(ids[:, 1:], [[0, 0], [0, 1], [0, 0]])

        with tf.variable_scope(tf.get_variable_scope(), reuse=False):
            latents_dense = embed(
                tf.one_hot(latents_discrete, depth=2**hparams.bottleneck_bits),
                hparams.hidden_size)
            latents_pred = transformer_latent_decoder(latents_dense,
                                                      inputs,
                                                      ed,
                                                      hparams,
                                                      name="latent_prediction")
            logits = tf.layers.dense(latents_pred,
                                     2**hparams.bottleneck_bits,
                                     name="logits_dense")
            current_output_position = common_layers.shape_list(ids)[1] - 1
            logits = logits[:, current_output_position, :]
        return logits

    initial_ids = tf.zeros([tf.shape(latents_dense_in)[0]], dtype=tf.int32)
    length = tf.shape(latents_dense_in)[1]
    ids, _, _ = beam_search.beam_search(symbols_to_logits_fn,
                                        initial_ids,
                                        1,
                                        length,
                                        2**hparams.bottleneck_bits,
                                        alpha=0.0,
                                        eos_id=-1,
                                        stop_early=False)

    res = tf.expand_dims(ids[:, 0, :], axis=2)  # Pick first beam.
    return res[:, 1:]  # Remove the added all-zeros from ids.
Example #12
def ae_latent_sample_beam(latents_dense_in, inputs, ed, embed, hparams):
  """Samples from the latent space in the autoencoder.

  Args:
    latents_dense_in: Tensor of shape [batch, length_q, ...]. Only the shape of
      its first two dimensions is used. length_q is the latent length, which is
      height * width * hparams.num_latents / (2**hparams.num_compress_steps).
    inputs: Tensor of shape [batch, length_kv, hparams.hidden_size]. Encodings
      to attend to in decoder.
    ed: Tensor which broadcasts with shape [batch, hparams.num_heads, length_q,
      length_kv]. Encoder-decoder attention bias.
    embed: Callable which embeds discrete latent hot-vectors and a hidden size
      and returns dense vectors.
    hparams: tf.contrib.training.HParams.

  Returns:
    Tensor of shape [batch, length].
  """

  def symbols_to_logits_fn(ids):
    """Go from ids to logits."""
    ids = tf.expand_dims(ids, axis=2)  # Ids start with added all-zeros.
    latents_discrete = tf.pad(ids[:, 1:], [[0, 0], [0, 1], [0, 0]])

    with tf.variable_scope(tf.get_variable_scope(), reuse=False):
      latents_dense = embed(
          tf.one_hot(latents_discrete, depth=2**hparams.bottleneck_bits),
          hparams.hidden_size)
      latents_pred = transformer_latent_decoder(
          latents_dense, inputs, ed, hparams, name="latent_prediction")
      logits = tf.layers.dense(
          latents_pred, 2**hparams.bottleneck_bits, name="logits_dense")
      current_output_position = common_layers.shape_list(ids)[1] - 1
      logits = logits[:, current_output_position, :]
    return logits

  initial_ids = tf.zeros([tf.shape(latents_dense_in)[0]], dtype=tf.int32)
  length = tf.shape(latents_dense_in)[1]
  ids, _ = beam_search.beam_search(
      symbols_to_logits_fn,
      initial_ids,
      1,
      length,
      2**hparams.bottleneck_bits,
      alpha=0.0,
      eos_id=-1,
      stop_early=False)

  res = tf.expand_dims(ids[:, 0, :], axis=2)  # Pick first beam.
  return res[:, 1:]  # Remove the added all-zeros from ids.
Example #13
  def testTPUBeam(self):
    batch_size = 1
    beam_size = 2
    vocab_size = 3
    decode_length = 3

    initial_ids = tf.constant([0] * batch_size)  # GO
    probabilities = tf.constant([[[0.1, 0.1, 0.8], [0.1, 0.1, 0.8]],
                                 [[0.4, 0.5, 0.1], [0.2, 0.4, 0.4]],
                                 [[0.05, 0.9, 0.05], [0.4, 0.4, 0.2]]])

    # The top beam is always selected so we should see the top beam's state
    # at each position, which is the one that's getting 3 added to it each step.
    expected_states = tf.constant([[[0.], [0.]], [[3.], [3.]], [[6.], [6.]]])

    def symbols_to_logits(_, i, states):
      # We have to assert the values of state inline here since we can't fetch
      # them out of the loop!
      with tf.control_dependencies(
          [tf.assert_equal(states["state"], expected_states[i])]):
        logits = tf.to_float(tf.log(probabilities[i, :]))

      states["state"] += tf.constant([[3.], [7.]])
      return logits, states

    states = {
        "state": tf.zeros((batch_size, 1)),
    }
    states["state"] = tf.placeholder_with_default(
        states["state"], shape=(None, 1))

    final_ids, _ = beam_search.beam_search(
        symbols_to_logits,
        initial_ids,
        beam_size,
        decode_length,
        vocab_size,
        3.5,
        eos_id=1,
        states=states,
        use_tpu=True)

    with self.test_session() as sess:
      # Catch and fail so that the testing framework doesn't think it's an error
      try:
        final_ids = sess.run(final_ids)
      except tf.errors.InvalidArgumentError as e:
        raise AssertionError(e.message)
    self.assertAllEqual([[[0, 2, 0, 1], [0, 2, 1, 0]]], final_ids)
Example #14
  def testStateBeamTwo(self):
    batch_size = 1
    beam_size = 2
    vocab_size = 3
    decode_length = 3

    initial_ids = tf.constant([0] * batch_size)  # GO
    probabilities = tf.constant([[[0.1, 0.1, 0.8], [0.1, 0.1, 0.8]],
                                 [[0.4, 0.5, 0.1], [0.2, 0.4, 0.4]],
                                 [[0.05, 0.9, 0.05], [0.4, 0.4, 0.2]]])

    # The top beam is always selected so we should see the top beam's state
    # at each position, which is the one that's getting 3 added to it each step.
    expected_states = tf.constant([[[0.], [0.]], [[3.], [3.]], [[6.], [6.]]])

    def symbols_to_logits(ids, _, states):
      pos = tf.shape(ids)[1] - 1

      # We have to assert the values of state inline here since we can't fetch
      # them out of the loop!
      with tf.control_dependencies(
          [tf.assert_equal(states["state"], expected_states[pos])]):
        logits = tf.to_float(tf.log(probabilities[pos, :]))

      states["state"] += tf.constant([[3.], [7.]])
      return logits, states

    states = {
        "state": tf.zeros((batch_size, 1)),
    }
    states["state"] = tf.placeholder_with_default(
        states["state"], shape=(None, 1))

    final_ids, _, _ = beam_search.beam_search(
        symbols_to_logits,
        initial_ids,
        beam_size,
        decode_length,
        vocab_size,
        0.0,
        eos_id=1,
        states=states)

    with self.test_session() as sess:
      # Catch and fail so that the testing framework doesn't think it's an error
      try:
        sess.run(final_ids)
      except tf.errors.InvalidArgumentError as e:
        raise AssertionError(e.message)
Example #15
def beam_search_decoding(length=1000):
    initial_ids = tf.constant(char2idx['<start>'], tf.int32, [1])

    def symbols_to_logits(ids):
        logits = model.forward(ids)
        return logits[:, tf.shape(ids)[1] - 1, :]

    final_ids, final_probs, _ = beam_search.beam_search(
        symbols_to_logits,
        initial_ids,
        5,
        length,
        len(char2idx),
        0.0,
        eos_id=char2idx['<end>'])

    return final_ids[0, 0, :]
Example #16
def beam_search_decoding(length=20, beam_width=5):
    initial_ids = tf.fill([model.batch_size], GO)

    def symbols_to_logits(ids):
        x = tf.contrib.seq2seq.tile_batch(model.X, beam_width)
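        # tile_batch repeats each source row beam_width times, matching the
        # flattened [batch_size * beam_width, ...] layout of `ids`.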
        logits = model.forward(x, ids, reuse=True)
        return logits[:, tf.shape(ids)[1] - 1, :]

    final_ids, final_probs, _ = beam_search.beam_search(symbols_to_logits,
                                                        initial_ids,
                                                        beam_width,
                                                        length,
                                                        len(vocab2id),
                                                        0.0,
                                                        eos_id=EOS)

    return final_ids
Example #17
  def testNotGreedyBatchTwoBeamTwoWithAlpha(self):
    batch_size = 2
    beam_size = 2
    vocab_size = 3
    decode_length = 3

    initial_ids = tf.constant([0] * batch_size)  # GO
    # Probabilities for position * batch * beam * vocab
    # Probabilities have been set such that with alpha = 3.5, the less probable
    # but longer sequence will have a better score than the shorter sequence
    # with higher log prob in batch 1, and the order will be reverse in batch
    # 2. That is, the shorter sequence will still have a higher score in spite
    # of the length penalty
    probabilities = tf.constant([[[[0.1, 0.1, 0.8], [0.1, 0.1, 0.8]],
                                  [[0.1, 0.1, 0.8], [0.1, 0.1, 0.8]]],
                                 [[[0.4, 0.5, 0.1], [0.2, 0.4, 0.4]],
                                  [[0.3, 0.6, 0.1], [0.2, 0.4, 0.4]]],
                                 [[[0.05, 0.9, 0.05], [0.4, 0.4, 0.2]],
                                  [[0.05, 0.9, 0.05], [0.4, 0.4, 0.2]]]])

    def symbols_to_logits(ids):
      pos = tf.shape(ids)[1]
      logits = tf.to_float(tf.log(probabilities[pos - 1, :]))
      return logits

    final_ids, final_scores = beam_search.beam_search(
        symbols_to_logits,
        initial_ids,
        beam_size,
        decode_length,
        vocab_size,
        3.5,
        eos_id=1)

    with self.test_session():
      ids = final_ids.eval()
      scores = final_scores.eval()
    self.assertAllEqual([[[0, 2, 0, 1], [0, 2, 1, 0]], [[0, 2, 1, 0],
                                                        [0, 2, 0, 1]]], ids)
    self.assertAllClose([[
        np.log(0.8 * 0.4 * 0.9) / (8. / 6.)**3.5,
        np.log(0.8 * 0.5) / (7. / 6.)**3.5
    ], [
        np.log(0.8 * 0.6) / (7. / 6.)**3.5,
        np.log(0.8 * 0.3 * 0.9) / (8. / 6.)**3.5
    ]], scores)
Example #18
File: evaluation.py Project: suyash/mlt
def predict(model,
            inputs,
            inpf,
            tarf,
            bos_id,
            eos_id,
            beam_size,
            vocab_size,
            alpha=1.0,
            decode_length=40):
    """
    inputs: already int encoded set of inputs, [batch_size, ?], tf.int32
    """

    batch_size = inputs.shape[0]
    initial_ids = [bos_id] * batch_size

    enc_input = tf.expand_dims(inputs, 1)
    enc_input = tf.tile(enc_input, [1, beam_size, 1])
    enc_input = tf.reshape(enc_input, [batch_size * beam_size, -1])
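    # The tiled, flattened encoder input now matches the
    # [batch_size * beam_size, ...] layout of `ids` inside the search.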

    def symbols_to_logits(ids):
        logits = model([
            enc_input,
            tf.tile(tf.expand_dims(inpf, 0), [tf.shape(ids)[0], 1]),
            ids,
            tf.tile(tf.expand_dims(tarf, 0), [tf.shape(ids)[0], 1]),
        ])

        logits = logits[0][:, -1, :]
        return logits

    x = beam_search(symbols_to_logits,
                    initial_ids,
                    beam_size,
                    decode_length,
                    vocab_size,
                    alpha=alpha,
                    eos_id=eos_id)

    ids = x[0]
    probs = x[1]

    return ids, probs
Example #19
  def testStates(self):
    batch_size = 1
    beam_size = 1
    vocab_size = 2
    decode_length = 3

    initial_ids = tf.constant([0] * batch_size)  # GO
    probabilities = tf.constant([[[0.7, 0.3]], [[0.4, 0.6]], [[0.5, 0.5]]])

    expected_states = tf.constant([[[0.]], [[1.]]])

    def symbols_to_logits(ids, _, states):
      pos = tf.shape(ids)[1] - 1
      # We have to assert the values of state inline here since we can't fetch
      # them out of the loop!
      with tf.control_dependencies(
          [tf.assert_equal(states["state"], expected_states[pos])]):
        logits = tf.to_float(tf.log(probabilities[pos, :]))

      states["state"] += 1
      return logits, states

    states = {
        "state": tf.zeros((batch_size, 1)),
    }
    states["state"] = tf.placeholder_with_default(
        states["state"], shape=(None, 1))

    final_ids, _, _ = beam_search.beam_search(
        symbols_to_logits,
        initial_ids,
        beam_size,
        decode_length,
        vocab_size,
        0.0,
        eos_id=1,
        states=states)

    with self.test_session() as sess:
      # Catch and fail so that the testing framework doesn't think it's an error
      try:
        sess.run(final_ids)
      except tf.errors.InvalidArgumentError as e:
        raise AssertionError(e.message)
Example #20
  def testStates(self):
    batch_size = 1
    beam_size = 1
    vocab_size = 2
    decode_length = 3

    initial_ids = tf.constant([0] * batch_size)  # GO
    probabilities = tf.constant([[[0.7, 0.3]], [[0.4, 0.6]], [[0.5, 0.5]]])

    expected_states = tf.constant([[[0.]], [[1.]]])

    def symbols_to_logits(ids, _, states):
      pos = tf.shape(ids)[1] - 1
      # We have to assert the values of state inline here since we can't fetch
      # them out of the loop!
      with tf.control_dependencies(
          [tf.assert_equal(states["state"], expected_states[pos])]):
        logits = tf.to_float(tf.log(probabilities[pos, :]))

      states["state"] += 1
      return logits, states

    states = {
        "state": tf.zeros((batch_size, 1)),
    }
    states["state"] = tf.placeholder_with_default(
        states["state"], shape=(None, 1))

    final_ids, _ = beam_search.beam_search(
        symbols_to_logits,
        initial_ids,
        beam_size,
        decode_length,
        vocab_size,
        0.0,
        eos_id=1,
        states=states)

    with self.test_session() as sess:
      # Catch and fail so that the testing framework doesn't think it's an error
      try:
        sess.run(final_ids)
      except tf.errors.InvalidArgumentError as e:
        raise AssertionError(e.message)
Example #21
  def testShapes(self):
    batch_size = 2
    beam_size = 3
    vocab_size = 4
    decode_length = 10

    initial_ids = tf.constant([0, 0])  # GO

    def symbols_to_logits(_):
      # Just return random logits
      return tf.random_uniform((batch_size * beam_size, vocab_size))

    final_ids, final_probs = beam_search.beam_search(
        symbols_to_logits, initial_ids, beam_size, decode_length, vocab_size,
        0.)

    self.assertEqual(final_ids.get_shape().as_list(), [None, beam_size, None])

    self.assertEqual(final_probs.get_shape().as_list(), [batch_size, beam_size])
Example #22
    def forward(self, inputs, labels, masks, training):
        b = tf.shape(labels)[0]
        label_inputs = tf.concat([tf.zeros((b, 1), tf.int64), labels[:, :-1]],
                                 axis=1)
        mask_inputs = tf.concat([tf.ones((b, 1), tf.int32), masks[:, :-1]],
                                axis=1)

        emb_inputs = tf.nn.embedding_lookup(self.embedding, label_inputs)
        rnn_masks = tf.cast(tf.expand_dims(mask_inputs, 2), tf.float32)
        rnn_inputs = tf.multiply(emb_inputs, rnn_masks)

        self.dropout.apply(rnn_inputs, training=training)

        h0 = self.fc.apply(inputs)
        rnn_outputs = self.decoder.apply(rnn_inputs, initial_state=[h0, h0])
        self.dropout.apply(rnn_outputs, training=training)

        logits = tf.matmul(rnn_outputs, self.embedding, transpose_b=True)

        loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels,
                                                              logits=logits)
        loss = tf.reduce_sum(loss * tf.cast(masks, tf.float32), axis=-1)
        loss = tf.reduce_mean(loss)

        def pred_fn(x, i, states):
            e = tf.nn.embedding_lookup(self.embedding, x)
            r = self.decoder.apply(e, initial_state=states)[:, -1, :]
            o = tf.matmul(r, self.embedding, transpose_b=True)
            return o, states

        initial_ids = tf.ones((b, ), tf.int32) * self.sos_id
        beam_preds, _, _ = beam_search.beam_search(
            pred_fn,
            initial_ids,
            alpha=0.,
            beam_size=self.beam_size,
            decode_length=self.seq_len,
            vocab_size=self.output_size + 1,
            eos_id=self.eos_id,
            states=[h0, h0])

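        # Keep the best beam and drop the initial SOS id.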
        return beam_preds[:, 0, 1:], loss
Example #23
  def testNotGreedyBeamTwoWithStopEarly(self):
    batch_size = 1
    beam_size = 2
    vocab_size = 3
    decode_length = 3

    initial_ids = tf.constant([0] * batch_size)  # GO
    probabilities = tf.constant([[[0.1, 0.1, 0.8], [0.1, 0.1, 0.8]],
                                 [[0.4, 0.5, 0.1], [0.2, 0.4, 0.4]],
                                 [[0.05, 0.9, 0.05], [0.4, 0.4, 0.2]]])

    def symbols_to_logits(ids):
      pos = tf.shape(ids)[1]
      logits = tf.to_float(tf.log(probabilities[pos - 1, :]))
      return logits

    final_ids, final_probs = beam_search.beam_search(
        symbols_to_logits,
        initial_ids,
        beam_size,
        decode_length,
        vocab_size,
        0.0,
        eos_id=1,
        stop_early=True)  # default value, but just to make this explicit

    with self.test_session():
      ids = final_ids.eval()
      probs = final_probs.eval()
    # Given stop_early=True, the only guarantee is for the first beam
    # (other beams may not even be complete), so we check only the first beam.
    first_beam = ids[:, 0]
    first_probs = probs[:, 0]
    self.assertAllEqual([[0, 2, 1]], first_beam)
    self.assertAllClose([0.8 * 0.5], np.exp(first_probs))
Example #24
  def _beam_decode_slow(self, features, decode_length, beam_size, top_beams,
                        alpha):
    """Slow version of Beam search decoding.

    Quadratic time in decode_length.

    Args:
      features: a map of string to `Tensor`
      decode_length: an integer.  How many additional timesteps to decode.
      beam_size: number of beams.
      top_beams: an integer. How many of the beams to return.
      alpha: Float that controls the length penalty. The larger the alpha, the
        stronger the preference for longer translations.

    Returns:
       samples: an integer `Tensor`. Top samples from the beam search
    """
    batch_size = common_layers.shape_list(features["inputs"])[0]
    batch_size = tf.Print(batch_size, [batch_size], "beam_decode batch_size=")

    def symbols_to_logits_fn(ids):
      """Go from ids to logits."""
      ids = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
      ids = tf.pad(ids[:, 1:], [[0, 0], [0, 1], [0, 0], [0, 0]])
      if "partial_targets" in features:
        pt = features["partial_targets"]
        pt_length = common_layers.shape_list(pt)[1]
        pt = tf.tile(pt, [1, beam_size])
        pt = tf.reshape(pt, [batch_size * beam_size, pt_length, 1, 1])
        ids = tf.concat([pt, ids], axis=1)
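        # Decoding then continues from the tiled partial targets prepended above.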

      features["targets"] = ids
      self._coverage = None
      logits, _ = self(features)  # pylint: disable=not-callable
      # now self._coverage is a coverage tensor for the first datashard.
      # it has shape [batch_size] and contains floats between 0 and
      # source_length.
      modality = self.hparams.problems[self._problem_idx].target_modality
      if modality.top_is_pointwise:
        return tf.squeeze(logits, axis=[1, 2, 3])
      # -1 due to the pad above.
      current_output_position = common_layers.shape_list(ids)[1] - 1
      logits = logits[:, current_output_position, :, :]
      return tf.squeeze(logits, axis=[1, 2])

    initial_ids = tf.zeros([batch_size], dtype=tf.int32)

    if self.has_input:
      inputs_old = features["inputs"]
      features["inputs"] = tf.expand_dims(features["inputs"], 1)
      if len(features["inputs"].shape) < 5:
        features["inputs"] = tf.expand_dims(features["inputs"], 4)
      # Expand the inputs in to the beam size.
      features["inputs"] = tf.tile(features["inputs"], [1, beam_size, 1, 1, 1])
      s = common_layers.shape_list(features["inputs"])
      features["inputs"] = tf.reshape(features["inputs"],
                                      [s[0] * s[1], s[2], s[3], s[4]])

    target_modality = self.hparams.problems[self._problem_idx].target_modality
    vocab_size = target_modality.top_dimensionality
    # Setting decode length to input length + decode_length
    decode_length = tf.constant(decode_length)
    if "partial_targets" not in features:
      decode_length += common_layers.shape_list(features["inputs"])[1]
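    # Stop early only when just the top beam is needed; otherwise run until
    # all beams are complete.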
    ids, scores = beam_search.beam_search(
        symbols_to_logits_fn,
        initial_ids,
        beam_size,
        decode_length,
        vocab_size,
        alpha,
        stop_early=(top_beams == 1))

    # Set inputs back to the unexpanded inputs so as not to confuse the Estimator!
    if self.has_input:
      features["inputs"] = inputs_old

    # Return `top_beams` decodings (also remove initial id from the beam search)
    return_scores = True  # TODO(lukaszkaiser): make it work multi-problem.
    if top_beams == 1:
      if return_scores:
        return {"outputs": ids[:, 0, 1:], "scores": scores}
      return ids[:, 0, 1:]
    else:
      if return_scores:
        return {"outputs": ids[:, :top_beams, 1:], "scores": scores}
      return ids[:, :top_beams, 1:]
Example #25
  def _beam_decode_slow(self, features, decode_length, beam_size, top_beams,
                        last_position_only, alpha):
    """Slow version of Beam search decoding.

    Quadratic time in decode_length.

    Args:
      features: a map of string to `Tensor`
      decode_length: an integer.  How many additional timesteps to decode.
      beam_size: number of beams.
      top_beams: an integer. How many of the beams to return.
      last_position_only: a boolean, speed-up by computing last position only.
      alpha: Float that controls the length penalty. The larger the alpha, the
        stronger the preference for longer translations.

    Returns:
       samples: an integer `Tensor`. Top samples from the beam search
    """
    batch_size = tf.shape(features["inputs"])[0]
    batch_size = tf.Print(batch_size, [batch_size], "beam_decode batch_size=")

    def symbols_to_logits_fn(ids):
      """Go from ids to logits."""
      ids = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
      ids = tf.pad(ids[:, 1:], [[0, 0], [0, 1], [0, 0], [0, 0]])
      if "partial_targets" in features:
        pt = features["partial_targets"]
        pt_length = tf.shape(pt)[1]
        pt = tf.tile(pt, [1, beam_size])
        pt = tf.reshape(pt, [batch_size * beam_size, pt_length, 1, 1])
        ids = tf.concat([pt, ids], axis=1)

      features["targets"] = ids
      self._coverage = None
      sharded_logits, _ = self.model_fn(
          features, False, last_position_only=last_position_only)
      # now self._coverage is a coverage tensor for the first datashard.
      # it has shape [batch_size] and contains floats between 0 and
      # source_length.
      logits = sharded_logits[0]  # Assuming we have one shard.
      if last_position_only:
        return tf.squeeze(logits, axis=[1, 2, 3])
      current_output_position = tf.shape(ids)[1] - 1  # -1 due to the pad above.
      logits = logits[:, current_output_position, :, :]
      return tf.squeeze(logits, axis=[1, 2])

    initial_ids = tf.zeros([batch_size], dtype=tf.int32)

    if self.has_input:
      inputs_old = features["inputs"]
      features["inputs"] = tf.expand_dims(features["inputs"], 1)
      if len(features["inputs"].shape) < 5:
        features["inputs"] = tf.expand_dims(features["inputs"], 4)
      # Expand the inputs in to the beam size.
      features["inputs"] = tf.tile(features["inputs"], [1, beam_size, 1, 1, 1])
      s = tf.shape(features["inputs"])
      features["inputs"] = tf.reshape(features["inputs"],
                                      [s[0] * s[1], s[2], s[3], s[4]])

    target_modality = self._hparams.problems[self._problem_idx].target_modality
    vocab_size = target_modality.top_dimensionality
    # Setting decode length to input length + decode_length
    decode_length = tf.constant(decode_length)
    if "partial_targets" not in features:
      decode_length += tf.shape(features["inputs"])[1]
    ids, scores = beam_search.beam_search(symbols_to_logits_fn, initial_ids,
                                          beam_size, decode_length, vocab_size,
                                          alpha)

    # Set inputs back to the unexpanded inputs so as not to confuse the Estimator!
    if self.has_input:
      features["inputs"] = inputs_old

    # Return `top_beams` decodings (also remove initial id from the beam search)
    return_scores = False  # TODO(lukaszkaiser): make it work multi-problem.
    if top_beams == 1:
      if return_scores:
        return {"outputs": ids[:, 0, 1:], "scores": scores}
      return ids[:, 0, 1:]
    else:
      if return_scores:
        return {"outputs": ids[:, :top_beams, 1:], "scores": scores}
      return ids[:, :top_beams, 1:]
Example #26
def fast_decode(encoder_output,
                encoder_decoder_attention_bias,
                symbols_to_logits_fn,
                hparams,
                decode_length,
                vocab_size,
                beam_size=1,
                top_beams=1,
                alpha=1.0,
                eos_id=beam_search.EOS_ID,
                batch_size=None):
    """Given encoder output and a symbols to logits function, does fast decoding.
  Implements both greedy and beam search decoding, uses beam search iff
  beam_size > 1, otherwise beam search related arguments are ignored.
  Args:
    encoder_output: Output from encoder.
    encoder_decoder_attention_bias: a bias tensor for use in encoder-decoder
      attention
    symbols_to_logits_fn: Incremental decoding; function mapping triple
      `(ids, step, cache)` to symbol logits.
    hparams: run hyperparameters
    decode_length: an integer.  How many additional timesteps to decode.
    vocab_size: Output vocabulary size.
    beam_size: number of beams.
    top_beams: an integer. How many of the beams to return.
    alpha: Float that controls the length penalty. The larger the alpha, the
      stronger the preference for longer translations.
    eos_id: End-of-sequence symbol in beam search.
    batch_size: an integer scalar - must be passed if there is no input
  Returns:
      A dict of decoding results {
          "outputs": integer `Tensor` of decoded ids of shape
              [batch_size, <= decode_length] if top_beams == 1 or
              [batch_size, top_beams, <= decode_length] otherwise
          "scores": decoding log probs from the beam search,
              None if using greedy decoding (beam_size=1)
      }
    Raises:
      NotImplementedError: If beam size > 1 with partial targets.
  """

    if encoder_output is not None:
        batch_size = common_layers.shape_list(encoder_output)[0]

    key_channels = hparams.attention_key_channels or hparams.hidden_size
    value_channels = hparams.attention_value_channels or hparams.hidden_size
    num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers

    cache = {
        "layer_%d" % layer: {
            "k": tf.zeros([batch_size, 0, key_channels]),
            "v": tf.zeros([batch_size, 0, value_channels]),
        }
        for layer in range(num_layers)
    }

    if encoder_output is not None:
        cache["encoder_output"] = encoder_output
        cache[
            "encoder_decoder_attention_bias"] = encoder_decoder_attention_bias

    if beam_size > 1:  # Beam Search
        initial_ids = tf.zeros([batch_size], dtype=tf.int32)
        decoded_ids, scores = beam_search.beam_search(
            symbols_to_logits_fn,
            initial_ids,
            beam_size,
            decode_length,
            vocab_size,
            alpha,
            states=cache,
            eos_id=eos_id,
            stop_early=(top_beams == 1))

        if top_beams == 1:
            decoded_ids = decoded_ids[:, 0, 1:]
        else:
            decoded_ids = decoded_ids[:, :top_beams, 1:]
        """ t2t_csaky code """
        # do roulette wheel selection or inverse roulette wheel selection
        if hparams.roulette == "Normal" or hparams.roulette == "Inverse":
            if hparams.roulette == "Normal":
                probabilities = tf.pow(tf.constant(2.0), scores)
                start = 0
            else:
                probabilities = tf.subtract(tf.constant(1.0),
                                            tf.pow(tf.constant(2.0), scores))
                start = beam_size - hparams.roulette_beam_size

            ex_probs = tf.divide(probabilities, tf.reduce_sum(probabilities))
            #ex_probs=tf.nn.softmax(probabilities)

            # sample a number between 0 and 1
            wheel = tf.random_uniform([1])
            upper_bound = tf.constant(0.0)

            # change this as well if using inverse
            for i in range(start, hparams.roulette_beam_size):
                upper_bound = tf.add(ex_probs[:, i], upper_bound)
                truthValue = tf.squeeze(
                    tf.logical_and(wheel >= upper_bound - ex_probs[:, i],
                                   wheel <= upper_bound))
                decoded_ids, scores, i = tf.cond(
                    truthValue, lambda:
                    (decoded_ids[:, i, :], scores[:, i], beam_size), lambda:
                    (decoded_ids, scores, i))

    else:  # Greedy

        def inner_loop(i, finished, next_id, decoded_ids, cache):
            """One step of greedy decoding."""
            logits, cache = symbols_to_logits_fn(next_id, i, cache)
            temperature = (0.0 if hparams.sampling_method == "argmax" else
                           hparams.sampling_temp)
            next_id = common_layers.sample_with_temperature(
                logits, temperature)
            finished |= tf.equal(next_id, eos_id)
            next_id = tf.expand_dims(next_id, axis=1)
            decoded_ids = tf.concat([decoded_ids, next_id], axis=1)
            return i + 1, finished, next_id, decoded_ids, cache

        def is_not_finished(i, finished, *_):
            return (i < decode_length) & tf.logical_not(
                tf.reduce_all(finished))

        decoded_ids = tf.zeros([batch_size, 0], dtype=tf.int64)
        finished = tf.fill([batch_size], False)
        next_id = tf.zeros([batch_size, 1], dtype=tf.int64)
        _, _, _, decoded_ids, _ = tf.while_loop(
            is_not_finished,
            inner_loop,
            [tf.constant(0), finished, next_id, decoded_ids, cache],
            shape_invariants=[
                tf.TensorShape([]),
                tf.TensorShape([None]),
                tf.TensorShape([None, None]),
                tf.TensorShape([None, None]),
                nest.map_structure(beam_search.get_state_shape_invariants,
                                   cache),
            ])
        scores = None

    return {
        "outputs": decoded_ids,
        "encoder_outputs": encoder_output,
        "scores": scores
    }
Example #27
    def transformer_beam_search(self, encoder_outputs, encoder_attn_bias, encoder_embed_inputs_list,
                                sentence_complex_input_placeholder, emb_simple, w, b,
                                rule_id_input_placeholder, mem_contexts, mem_outputs, global_step,
                                score, obj, obj_tensors):
        # Use Beam Search in evaluation stage
        # Update [a, b, c] to [a, a, a, b, b, b, c, c, c] if beam_search_size == 3
        encoder_beam_outputs = tf.concat(
            [tf.tile(tf.expand_dims(encoder_outputs[o, :, :], axis=0),
                     [self.model_config.beam_search_size, 1, 1])
             for o in range(self.model_config.batch_size)], axis=0)

        encoder_embed_inputs = tf.stack(encoder_embed_inputs_list, axis=1)
        encoder_beam_embed_inputs = tf.concat(
            [tf.tile(tf.expand_dims(encoder_embed_inputs[o, :, :], axis=0),
                     [self.model_config.beam_search_size, 1, 1])
             for o in range(self.model_config.batch_size)], axis=0)

        encoder_attn_beam_bias = tf.concat(
            [tf.tile(tf.expand_dims(encoder_attn_bias[o, :, :, :], axis=0),
                     [self.model_config.beam_search_size, 1, 1, 1])
             for o in range(self.model_config.batch_size)], axis=0)

        if 'direct' in self.model_config.memory:
            obj_tensors['direct_bert_output_bak'] = obj_tensors['direct_bert_output']
            obj_tensors['direct_bert_bias_bak'] = obj_tensors['direct_bert_bias']
            obj_tensors['direct_bert_output'] = tf.concat(
                [tf.tile(tf.expand_dims(obj_tensors['direct_bert_output'][o, :, :], axis=0),
                         [self.model_config.beam_search_size, 1, 1])
                 for o in range(self.model_config.batch_size)], axis=0)
            obj_tensors['direct_bert_bias'] = tf.concat(
                [tf.tile(tf.expand_dims(obj_tensors['direct_bert_bias'][o, :, :, :], axis=0),
                         [self.model_config.beam_search_size, 1, 1, 1])
                 for o in range(self.model_config.batch_size)], axis=0)

        if self.model_config.subword_vocab_size or 'bert_token' in self.model_config.bert_mode:
            go_id = self.data.vocab_simple.encode(constant.SYMBOL_GO)[0]
            eos_id = self.data.vocab_simple.encode(constant.SYMBOL_END)[0]
        else:
            go_id = self.data.vocab_simple.encode(constant.SYMBOL_GO)
            eos_id = self.data.vocab_simple.encode(constant.SYMBOL_END)
        batch_go = tf.expand_dims(tf.tile(
            tf.expand_dims(self.embedding_fn(go_id, emb_simple), axis=0),
            [self.model_config.batch_size, 1]), axis=1)
        batch_go_beam = tf.concat(
            [tf.tile(tf.expand_dims(batch_go[o, :, :], axis=0),
                     [self.model_config.beam_search_size, 1, 1])
             for o in range(self.model_config.batch_size)], axis=0)

        def symbol_to_logits_fn(ids):
            cur_ids = ids[:, 1:]
            embs = tf.nn.embedding_lookup(emb_simple, cur_ids)
            embs = tf.concat([batch_go_beam, embs], axis=1)

            final_outputs, _, _ = self.decode_inputs_to_outputs(embs, encoder_beam_outputs, encoder_attn_beam_bias,
                                                                rule_id_input_placeholder, mem_contexts, mem_outputs,
                                                                global_step, score, obj_tensors=obj_tensors)

            decoder_logit_list = self.output_to_logit(final_outputs[:, -1, :], w, b)

            if self.model_config.pointer_mode:
                segment_mask = None
                if 'line_comp_segids' in obj:
                    segment_mask = obj['line_comp_segids']
                decoder_logit_list = word_distribution(
                    [decoder_logit_list], [final_outputs[:, -1, :]],
                    encoder_beam_outputs, encoder_beam_embed_inputs,
                    sentence_complex_input_placeholder,
                    obj_tensors, self.model_config, self.data, segment_mask, is_test=True)

            return decoder_logit_list

        beam_ids, beam_score = beam_search.beam_search(symbol_to_logits_fn,
                                                       tf.ones([self.model_config.batch_size], tf.int32) * go_id,
                                                       self.model_config.beam_search_size,
                                                       self.model_config.max_simple_sentence,
                                                       self.data.vocab_simple.vocab_size(),
                                                       self.model_config.penalty_alpha,
                                                       eos_id=eos_id
                                                       )
        top_beam_ids = beam_ids[:, 0, 1:]
        top_beam_ids = tf.pad(top_beam_ids,
                              [[0, 0],
                               [0, self.model_config.max_simple_sentence - tf.shape(top_beam_ids)[1]]])

        decoder_target_list = [tf.squeeze(d, 1)
                               for d in tf.split(top_beam_ids, self.model_config.max_simple_sentence, axis=1)]
        decoder_score = -beam_score[:, 0] / tf.to_float(tf.shape(top_beam_ids)[1])

        # Get outputs based on target ids
        decode_input_embs = tf.stack(self.embedding_fn(decoder_target_list, emb_simple), axis=1)
        tf.get_variable_scope().reuse_variables()
        if 'direct' in self.model_config.memory:
            obj_tensors['direct_bert_output'] = obj_tensors['direct_bert_output_bak']
            obj_tensors['direct_bert_bias'] = obj_tensors['direct_bert_bias_bak']
        final_outputs, decoder_outputs, _ = self.decode_inputs_to_outputs(decode_input_embs, encoder_outputs, encoder_attn_bias,
                                                                          rule_id_input_placeholder, mem_contexts,
                                                                          mem_outputs, global_step, score,
                                                                          obj_tensors=obj_tensors)
        output = ModelOutput(
            encoder_outputs=encoder_outputs,
            final_outputs_list=final_outputs,
            decoder_outputs_list=decoder_outputs,
            decoder_score=decoder_score,
            decoder_target_list=decoder_target_list,
            encoder_embed_inputs_list=encoder_embed_inputs_list
        )
        return output
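A minimal standalone sketch of the top-beam slicing and length-normalized scoring used above, a pattern that recurs throughout the examples below (TF1.x-style API assumed; shapes and values are illustrative only):

import tensorflow as tf

# beam_search returns ids shaped [batch, beam, length]; beam index 0 is the
# best hypothesis and position 0 holds the initial GO id.
beam_ids = tf.constant([[[0, 4, 2, 1], [0, 4, 3, 1]]])   # [1, 2, 4]
beam_score = tf.constant([[-1.2, -2.3]])                 # log probs, [1, 2]
max_len = 6

top_beam_ids = beam_ids[:, 0, 1:]                        # drop GO id: [1, 3]
top_beam_ids = tf.pad(
    top_beam_ids,
    [[0, 0], [0, max_len - tf.shape(top_beam_ids)[1]]])  # zero-pad to [1, 6]
# Negate the log prob and divide by the (padded) length to get a per-token
# score, as in the snippet above.
decoder_score = -beam_score[:, 0] / tf.to_float(tf.shape(top_beam_ids)[1])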
Example #32
def fast_decode(encoder_output,
                encoder_decoder_attention_bias,
                symbols_to_logits_fn,
                hparams,
                decode_length,
                vocab_size,
                beam_size=1,
                top_beams=1,
                alpha=1.0,
                eos_id=beam_search.EOS_ID,
                batch_size=None,
                sentence_cache=None,
                cache_flag=None):
    """Given encoder output and a symbols to logits function, does fast decoding.

  Implements both greedy and beam search decoding, uses beam search iff
  beam_size > 1, otherwise beam search related arguments are ignored.

  Args:
    encoder_output: Output from encoder.
    encoder_decoder_attention_bias: a bias tensor for use in encoder-decoder
      attention
    symbols_to_logits_fn: Incremental decoding; function mapping triple
      `(ids, step, cache)` to symbol logits.
    hparams: run hyperparameters
    decode_length: an integer.  How many additional timesteps to decode.
    vocab_size: Output vocabulary size.
    beam_size: number of beams.
    top_beams: an integer. How many of the beams to return.
    alpha: Float that controls the length penalty. larger the alpha, stronger
      the preference for longer translations.
    eos_id: End-of-sequence symbol in beam search.
    batch_size: an integer scalar - must be passed if there is no input
    sentence_cache: external cache object whose `AddMultipleEntries` method is
      called through tf.py_func at each greedy step.
    cache_flag: scalar int64 tensor threaded through the decode loop to force
      the cache update to run.

  Returns:
      A dict of decoding results {
          "outputs": integer `Tensor` of decoded ids of shape
              [batch_size, <= decode_length] if top_beams == 1 or
              [batch_size, top_beams, <= decode_length] otherwise
          "scores": decoding log probs from the beam search,
              None if using greedy decoding (beam_size=1)
      }

    Raises:
      NotImplementedError: If beam size > 1 with partial targets.
  """
    if encoder_output is not None:
        batch_size = common_layers.shape_list(encoder_output)[0]

    key_channels = hparams.attention_key_channels or hparams.hidden_size
    value_channels = hparams.attention_value_channels or hparams.hidden_size
    num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers

    cache = {
        "layer_%d" % layer: {
            "k": tf.zeros([batch_size, 0, key_channels]),
            "v": tf.zeros([batch_size, 0, value_channels]),
        }
        for layer in range(num_layers)
    }

    if encoder_output is not None:
        cache["encoder_output"] = encoder_output
        cache[
            "encoder_decoder_attention_bias"] = encoder_decoder_attention_bias

    if beam_size > 1:  # Beam Search
        initial_ids = tf.zeros([batch_size], dtype=tf.int32)
        decoded_ids, scores = beam_search.beam_search(
            # Drop the third element (decoder outputs) so beam_search sees the
            # expected (logits, states) pair.
            lambda ids, i, cache: symbols_to_logits_fn(ids, i, cache)[:-1],
            initial_ids,
            beam_size,
            decode_length,
            vocab_size,
            alpha,
            states=cache,
            eos_id=eos_id,
            stop_early=(top_beams == 1))

        if top_beams == 1:
            decoded_ids = decoded_ids[:, 0, 1:]
        else:
            decoded_ids = decoded_ids[:, :top_beams, 1:]
    else:  # Greedy

        def inner_loop(cache_flag, i, finished, next_id, decoded_ids, cache):
            """One step of greedy decoding."""
            logits, cache, out = symbols_to_logits_fn(next_id, i, cache)
            temperature = (0.0 if hparams.sampling_method == "argmax" else
                           hparams.sampling_temp)
            next_id = common_layers.sample_with_temperature(
                logits, temperature)
            finished |= tf.equal(next_id, eos_id)
            next_id = tf.expand_dims(next_id, axis=1)

            # Push the new ids and decoder outputs into the external sentence
            # cache; the returned flag is threaded through the loop so the
            # side effect is not pruned from the graph.
            cache_flag = tf.py_func(sentence_cache.AddMultipleEntries,
                                    [next_id, out], tf.int64)
            cache_flag.set_shape(tf.TensorShape([]))

            decoded_ids = tf.concat([decoded_ids, next_id], axis=1)
            return cache_flag, i + 1, finished, next_id, decoded_ids, cache

        def is_not_finished(cache_flag, i, finished, *_):
            return (i < decode_length) & tf.logical_not(
                tf.reduce_all(finished))

        decoded_ids = tf.zeros([batch_size, 0], dtype=tf.int64)
        finished = tf.fill([batch_size], False)
        next_id = tf.zeros([batch_size, 1], dtype=tf.int64)
        cache_flag, _, _, _, decoded_ids, _ = tf.while_loop(
            is_not_finished,
            inner_loop, [
                cache_flag,
                tf.constant(0), finished, next_id, decoded_ids, cache
            ],
            shape_invariants=[
                tf.TensorShape([]),
                tf.TensorShape([]),
                tf.TensorShape([None]),
                tf.TensorShape([None, None]),
                tf.TensorShape([None, None]),
                nest.map_structure(beam_search.get_state_shape_invariants,
                                   cache),
            ])
        scores = None

    return {"outputs": decoded_ids + cache_flag, "scores": scores}
Example #33
def fast_decode(encoder_output,
                encoder_decoder_attention_bias,
                symbols_to_logits_fn,
                hparams,
                decode_length,
                vocab_size,
                beam_size=1,
                top_beams=1,
                alpha=1.0,
                eos_id=beam_search.EOS_ID,
                batch_size=None):
  """Given encoder output and a symbols to logits function, does fast decoding.

  Implements both greedy and beam search decoding, uses beam search iff
  beam_size > 1, otherwise beam search related arguments are ignored.

  Args:
    encoder_output: Output from encoder.
    encoder_decoder_attention_bias: a bias tensor for use in encoder-decoder
      attention
    symbols_to_logits_fn: Incremental decoding; function mapping triple
      `(ids, step, cache)` to symbol logits.
    hparams: run hyperparameters
    decode_length: an integer.  How many additional timesteps to decode.
    vocab_size: Output vocabulary size.
    beam_size: number of beams.
    top_beams: an integer. How many of the beams to return.
    alpha: Float that controls the length penalty. larger the alpha, stronger
      the preference for longer translations.
    eos_id: End-of-sequence symbol in beam search.
    batch_size: an integer scalar - must be passed if there is no input

  Returns:
      A dict of decoding results {
          "outputs": integer `Tensor` of decoded ids of shape
              [batch_size, <= decode_length] if top_beams == 1 or
              [batch_size, top_beams, <= decode_length] otherwise
          "scores": decoding log probs from the beam search,
              None if using greedy decoding (beam_size=1)
      }

    Raises:
      NotImplementedError: If beam size > 1 with partial targets.
  """
  if encoder_output is not None:
    batch_size = common_layers.shape_list(encoder_output)[0]

  key_channels = hparams.attention_key_channels or hparams.hidden_size
  value_channels = hparams.attention_value_channels or hparams.hidden_size
  num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers

  cache = {
      "layer_%d" % layer: {
          "k": tf.zeros([batch_size, 0, key_channels]),
          "v": tf.zeros([batch_size, 0, value_channels]),
      }
      for layer in range(num_layers)
  }

  if encoder_output is not None:
    cache["encoder_output"] = encoder_output
    cache["encoder_decoder_attention_bias"] = encoder_decoder_attention_bias

  if beam_size > 1:  # Beam Search
    initial_ids = tf.zeros([batch_size], dtype=tf.int32)
    decoded_ids, scores = beam_search.beam_search(
        symbols_to_logits_fn,
        initial_ids,
        beam_size,
        decode_length,
        vocab_size,
        alpha,
        states=cache,
        eos_id=eos_id,
        stop_early=(top_beams == 1))

    if top_beams == 1:
      decoded_ids = decoded_ids[:, 0, 1:]
    else:
      decoded_ids = decoded_ids[:, :top_beams, 1:]
  else:  # Greedy

    def inner_loop(i, finished, next_id, decoded_ids, cache):
      """One step of greedy decoding."""
      logits, cache = symbols_to_logits_fn(next_id, i, cache)
      temperature = (0.0 if hparams.sampling_method == "argmax" else
                     hparams.sampling_temp)
      next_id = common_layers.sample_with_temperature(logits, temperature)
      finished |= tf.equal(next_id, eos_id)
      next_id = tf.expand_dims(next_id, axis=1)
      decoded_ids = tf.concat([decoded_ids, next_id], axis=1)
      return i + 1, finished, next_id, decoded_ids, cache

    def is_not_finished(i, finished, *_):
      return (i < decode_length) & tf.logical_not(tf.reduce_all(finished))

    decoded_ids = tf.zeros([batch_size, 0], dtype=tf.int64)
    finished = tf.fill([batch_size], False)
    next_id = tf.zeros([batch_size, 1], dtype=tf.int64)
    _, _, _, decoded_ids, _ = tf.while_loop(
        is_not_finished,
        inner_loop,
        [tf.constant(0), finished, next_id, decoded_ids, cache],
        shape_invariants=[
            tf.TensorShape([]),
            tf.TensorShape([None]),
            tf.TensorShape([None, None]),
            tf.TensorShape([None, None]),
            nest.map_structure(beam_search.get_state_shape_invariants, cache),
        ])
    scores = None

  return {"outputs": decoded_ids, "scores": scores}
Example #34
    def _beam_decode_slow(self, features, decode_length, beam_size, top_beams,
                          alpha):
        """Slow version of Beam search decoding.

    Quadratic time in decode_length.

    Args:
      features: a map of string to `Tensor`
      decode_length: an integer.  How many additional timesteps to decode.
      beam_size: number of beams.
      top_beams: an integer. How many of the beams to return.
      alpha: Float that controls the length penalty. larger the alpha, stronger
        the preference for longer translations.

    Returns:
       samples: an integer `Tensor`. Top samples from the beam search
    """
        beam_size = 1
        if use_bottom_up_features:
            features["inputs"] = features["bottom_up_features"]

        batch_size = common_layers.shape_list(features["inputs"])[0]

        def symbols_to_logits_fn(ids):
            """Go from ids to logits."""
            ids = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
            ids = tf.pad(ids[:, 1:], [[0, 0], [0, 1], [0, 0], [0, 0]])
            if "partial_targets" in features:
                pt = features["partial_targets"]
                pt_length = common_layers.shape_list(pt)[1]
                pt = tf.tile(pt, [1, beam_size])
                pt = tf.reshape(pt, [batch_size * beam_size, pt_length, 1, 1])
                ids = tf.concat([pt, ids], axis=1)

            features["targets"] = ids
            self._coverage = None
            logits, _ = self(features)  # pylint: disable=not-callable
            # now self._coverage is a coverage tensor for the first datashard.
            # it has shape [batch_size] and contains floats between 0 and
            # source_length.
            if self._problem_hparams:
                modality = self._problem_hparams.target_modality
                if modality.top_is_pointwise:
                    return tf.squeeze(logits, axis=[1, 2, 3])
            # -1 due to the pad above.
            current_output_position = common_layers.shape_list(ids)[1] - 1
            logits = logits[:, current_output_position, :, :]
            return tf.squeeze(logits, axis=[1, 2])

        initial_ids = tf.zeros([batch_size], dtype=tf.int32)

        if self.has_input:
            inputs_old = features["inputs"]
            features["inputs"] = tf.expand_dims(features["inputs"], 1)
            if len(features["inputs"].shape) < 5:
                features["inputs"] = tf.expand_dims(features["inputs"], 4)
            # Expand the inputs in to the beam size.
            features["inputs"] = tf.tile(features["inputs"],
                                         [1, beam_size, 1, 1, 1])
            s = common_layers.shape_list(features["inputs"])
            features["inputs"] = tf.reshape(features["inputs"],
                                            [s[0] * s[1], s[2], s[3], s[4]])

        features["bottom_up_features"] = features["inputs"]
        target_modality = self._problem_hparams.target_modality
        vocab_size = target_modality.top_dimensionality
        # Setting decode length to input length + decode_length
        decode_length = tf.constant(decode_length)
        if "partial_targets" not in features:
            decode_length += common_layers.shape_list(features["inputs"])[1]
        ids, scores = beam_search.beam_search(symbols_to_logits_fn,
                                              initial_ids,
                                              beam_size,
                                              decode_length,
                                              vocab_size,
                                              alpha,
                                              stop_early=(top_beams == 1))

        # Set inputs back to the unexpanded inputs so as not to confuse the Estimator!
        if self.has_input:
            features["inputs"] = inputs_old
            features["bottom_up_features"] = inputs_old

        # Return `top_beams` decodings (also remove initial id from the beam search)
        # TODO(lukaszkaiser): make it work multi-problem.
        if top_beams == 1:
            samples = ids[:, 0, 1:]
        else:
            samples = ids[:, :top_beams, 1:]

        return {"outputs": samples, "scores": scores}
Example #35
  def _fast_decode(self,
                   features,
                   decode_length,
                   beam_size=1,
                   top_beams=1,
                   alpha=1.0):
    """Fast decoding.

    Implements both greedy and beam search decoding, uses beam search iff
    beam_size > 1, otherwise beam search related arguments are ignored.

    Args:
      features: a map of string to model  features.
      decode_length: an integer.  How many additional timesteps to decode.
      beam_size: number of beams.
      top_beams: an integer. How many of the beams to return.
      alpha: Float that controls the length penalty. larger the alpha, stronger
        the preference for longer translations.

    Returns:
       samples: an integer `Tensor`. Top samples from the beam search

    Raises:
      NotImplementedError: If there are multiple data shards.
    """
    if self._num_datashards != 1:
      raise NotImplementedError("Fast decoding only supports a single shard.")
    dp = self._data_parallelism
    hparams = self._hparams

    inputs = features["inputs"]
    batch_size = common_layers.shape_list(inputs)[0]
    target_modality = self._problem_hparams.target_modality
    if target_modality.is_class_modality:
      decode_length = 1
    else:
      decode_length = common_layers.shape_list(inputs)[1] + decode_length

    # TODO(llion): Clean up this reshaping logic.
    inputs = tf.expand_dims(inputs, axis=1)
    if len(inputs.shape) < 5:
      inputs = tf.expand_dims(inputs, axis=4)
    s = common_layers.shape_list(inputs)
    inputs = tf.reshape(inputs, [s[0] * s[1], s[2], s[3], s[4]])
    # _shard_features called to ensure that the variable names match
    inputs = self._shard_features({"inputs": inputs})["inputs"]
    input_modality = self._problem_hparams.input_modality["inputs"]
    with tf.variable_scope(input_modality.name):
      inputs = input_modality.bottom_sharded(inputs, dp)
    with tf.variable_scope("body"):
      encoder_output, encoder_decoder_attention_bias = dp(
          self.encode, inputs, features["target_space_id"], hparams,
          features=features)
    encoder_output = encoder_output[0]
    encoder_decoder_attention_bias = encoder_decoder_attention_bias[0]

    if hparams.pos == "timing":
      timing_signal = common_attention.get_timing_signal_1d(
          decode_length + 1, hparams.hidden_size)

    def preprocess_targets(targets, i):
      """Performs preprocessing steps on the targets to prepare for the decoder.

      This includes:
        - Embedding the ids.
        - Flattening to 3D tensor.
        - Optionally adding timing signals.

      Args:
        targets: inputs ids to the decoder. [batch_size, 1]
        i: scalar, Step number of the decoding loop.

      Returns:
        Processed targets [batch_size, 1, hidden_dim]
      """
      # _shard_features called to ensure that the variable names match
      targets = self._shard_features({"targets": targets})["targets"]
      with tf.variable_scope(target_modality.name):
        targets = target_modality.targets_bottom_sharded(targets, dp)[0]
      targets = common_layers.flatten4d3d(targets)

      # TODO(llion): Explain! Is this even needed?
      targets = tf.cond(
          tf.equal(i, 0), lambda: tf.zeros_like(targets), lambda: targets)

      if hparams.pos == "timing":
        targets += timing_signal[:, i:i + 1]
      return targets

    decoder_self_attention_bias = (
        common_attention.attention_bias_lower_triangle(decode_length))
    if hparams.proximity_bias:
      decoder_self_attention_bias += common_attention.attention_bias_proximal(
          decode_length)

    def symbols_to_logits_fn(ids, i, cache):
      """Go from ids to logits for next symbol."""
      ids = ids[:, -1:]
      targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
      targets = preprocess_targets(targets, i)

      bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]

      with tf.variable_scope("body"):
        body_outputs = dp(
            self.decode, targets, cache["encoder_output"],
            cache["encoder_decoder_attention_bias"], bias, hparams, cache,
            nonpadding=_features_to_nonpadding(features, "targets"))

      with tf.variable_scope(target_modality.name):
        logits = target_modality.top_sharded(body_outputs, None, dp)[0]

      return tf.squeeze(logits, axis=[1, 2, 3]), cache

    key_channels = hparams.attention_key_channels or hparams.hidden_size
    value_channels = hparams.attention_value_channels or hparams.hidden_size
    num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers

    cache = {
        "layer_%d" % layer: {
            "k": tf.zeros([batch_size, 0, key_channels]),
            "v": tf.zeros([batch_size, 0, value_channels]),
        }
        for layer in range(num_layers)
    }

    # Set 2nd dim to None since it's not invariant in the tf.while_loop
    # Note: Tensor.set_shape() does not work here since it merges shape info.
    # TODO(llion); Find a more robust solution.
    # pylint: disable=protected-access
    if not context.in_eager_mode():
      for layer in cache:
        cache[layer]["k"]._shape = tf.TensorShape([None, None, key_channels])
        cache[layer]["v"]._shape = tf.TensorShape([None, None, value_channels])
    # pylint: enable=protected-access
    cache["encoder_output"] = encoder_output
    cache["encoder_decoder_attention_bias"] = encoder_decoder_attention_bias

    if beam_size > 1:  # Beam Search
      target_modality = (
          self._hparams.problems[self._problem_idx].target_modality)
      vocab_size = target_modality.top_dimensionality
      initial_ids = tf.zeros([batch_size], dtype=tf.int32)
      decoded_ids, scores = beam_search.beam_search(
          symbols_to_logits_fn,
          initial_ids,
          beam_size,
          decode_length,
          vocab_size,
          alpha,
          states=cache,
          stop_early=(top_beams == 1))

      if top_beams == 1:
        decoded_ids = decoded_ids[:, 0, 1:]
      else:
        decoded_ids = decoded_ids[:, :top_beams, 1:]
    else:  # Greedy

      def inner_loop(i, next_id, decoded_ids, cache):
        logits, cache = symbols_to_logits_fn(next_id, i, cache)
        temperature = (0.0 if hparams.sampling_method == "argmax" else
                       hparams.sampling_temp)
        next_id = tf.expand_dims(
            common_layers.sample_with_temperature(logits, temperature), axis=1)
        decoded_ids = tf.concat([decoded_ids, next_id], axis=1)
        return i + 1, next_id, decoded_ids, cache

      decoded_ids = tf.zeros([batch_size, 0], dtype=tf.int64)
      scores = None
      next_id = tf.zeros([batch_size, 1], dtype=tf.int64)
      _, _, decoded_ids, _ = tf.while_loop(
          # TODO(llion): Early stopping.
          lambda i, *_: tf.less(i, decode_length),
          inner_loop,
          [tf.constant(0), next_id, decoded_ids, cache],
          shape_invariants=[
              tf.TensorShape([]),
              tf.TensorShape([None, None]),
              tf.TensorShape([None, None]),
              nest.map_structure(lambda t: tf.TensorShape(t.shape), cache),
          ])

    return decoded_ids, scores
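Inside symbols_to_logits_fn above, the slice decoder_self_attention_bias[:, :, i:i+1, :i+1] extracts the single row of the causal mask needed at step i. A sketch of the same idea with a hand-built lower-triangular bias (plain TF ops standing in for common_attention; TF1.x assumed):

import tensorflow as tf

decode_length = 5
ones = tf.ones([decode_length, decode_length])
lower_triangle = tf.matrix_band_part(ones, -1, 0)  # 1s on and below diagonal
# -1e9 on future (masked) positions, 0 elsewhere, shaped [1, 1, L, L].
bias = -1e9 * (1.0 - lower_triangle)[tf.newaxis, tf.newaxis, :, :]
i = 2
step_bias = bias[:, :, i:i + 1, :i + 1]  # [1, 1, 1, i+1]: steps 0..i visible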
Example #36
    def transformer_beam_search(self, abstr_outputs, abstr_bias, emb_kword, proj_w, proj_b, hist_vector=None):
        # Use Beam Search in evaluation stage
        # Update [a, b, c] to [a, a, a, b, b, b, c, c, c] if beam_search_size == 3
        encoder_beam_outputs = tf.concat(
            [tf.tile(tf.expand_dims(abstr_outputs[o, :, :], axis=0),
                     [self.model_config.beam_search_size, 1, 1])
             for o in range(self.model_config.batch_size)], axis=0)

        encoder_attn_beam_bias = tf.concat(
            [tf.tile(tf.expand_dims(abstr_bias[o, :, :, :], axis=0),
                     [self.model_config.beam_search_size, 1, 1, 1])
             for o in range(self.model_config.batch_size)], axis=0)

        hist_beam_vector = tf.concat(
            [tf.tile(tf.expand_dims(hist_vector[o, :, :], axis=0),
                     [self.model_config.beam_search_size, 1, 1])
             for o in range(self.model_config.batch_size)], axis=0)

        if self.model_config.subword_vocab_size:
            go_id = self.voc_kword.encode(constant.SYMBOL_GO)[0]
        else:
            go_id = self.voc_kword.encode(constant.SYMBOL_GO)
        batch_go = tf.expand_dims(tf.tile(
            tf.expand_dims(self.embedding_fn(go_id, emb_kword), axis=0),
            [self.model_config.batch_size, 1]), axis=1)
        batch_go_beam = tf.concat(
            [tf.tile(tf.expand_dims(batch_go[o, :, :], axis=0),
                     [self.model_config.beam_search_size, 1, 1])
             for o in range(self.model_config.batch_size)], axis=0)


        def symbol_to_logits_fn(ids):
            cur_ids = ids[:, 1:]

            embs = tf.nn.embedding_lookup(emb_kword, cur_ids)

            embs = tf.concat([batch_go_beam, embs], axis=1)

            final_outputs = self.decode_inputs_to_outputs(
                embs, encoder_beam_outputs, encoder_attn_beam_bias, hist_vector=hist_beam_vector)

            return self.output_to_logit(final_outputs[:, -1, :], proj_w, proj_b)

        beam_ids, beam_score = beam_search.beam_search(
            symbol_to_logits_fn,
            tf.zeros([self.model_config.batch_size], tf.int32),
            self.model_config.beam_search_size,
            self.model_config.max_kword_len,
            self.voc_kword.vocab_size(),
            0.6
        )

        top_beam_ids = beam_ids[:, 0, 1:]
        top_beam_ids = tf.pad(top_beam_ids,
                              [[0, 0],
                               [0, self.model_config.max_kword_len - tf.shape(top_beam_ids)[1]]])
        decoder_target_list = [tf.squeeze(d, 1)
                               for d in tf.split(top_beam_ids, self.model_config.max_kword_len, axis=1)]
        decoder_score = -beam_score[:, 0] / tf.to_float(tf.shape(top_beam_ids)[1])

        return decoder_score, top_beam_ids
Example #37
    def decode_beam_search(self,
                           start_ids,
                           eos_id,
                           pad_id,
                           enc_output,
                           enc_mask,
                           scope="model"):
        batch_size = tf.shape(start_ids)[0]
        cache = {  # pylint: disable=g-complex-comprehension
            "layer_%d" % layer: {
                "uniform_avg": tf.zeros([batch_size, 1, self.model_dimension]),
            }
            for layer in range(self.num_layers)
        }
        cache["logits"] = tf.zeros([batch_size, 0, self.vocabulary_size])
        pos_indices = tf.range(self.max_dec_time_step, dtype=tf.int32)
        pos_indices = tf.reshape(pos_indices, [1, -1])
        pos_values = self.positional_embedding(pos_indices)

        def beam_search_tile(output, tile_pattern, final_shape):
            x = tf.tile(output, tile_pattern)
            x = tf.reshape(x, final_shape)
            return x

        enc_output_feature_dim = enc_output.get_shape().as_list()[2]
        enc_output = beam_search_tile(
            enc_output, [1, self.beam_size, 1],
            [batch_size * self.beam_size, -1, enc_output_feature_dim])
        enc_mask = beam_search_tile(enc_mask, [1, self.beam_size],
                                    [batch_size * self.beam_size, -1])

        def symbols_to_logits_fn(ids, step, cache):
            """Looks up ids to logits."""
            logging.info(
                "Running symbols to logits. ids=%s, step=%s, cache=%s", ids,
                step, cache)
            curr_id = ids[:, -1:]
            with tf.name_scope(scope):
                curr_embed = self.embedding(curr_id)
                input_mask = tf.ones(tf.shape(curr_embed)[:-1],
                                     dtype=tf.float32)
                if self.embedding_size != self.model_dimension:
                    curr_embed = self.input_bottleneck(curr_embed, input_mask)
                inputs = self.qact(
                    self.ln(curr_embed + pos_values[:, step:step + 1, :]))
                layer_out = self.transformer_uniform_attn_decoder(
                    inputs, input_mask, enc_output, enc_mask,
                    step=step + 1, cache=cache)
                next_logits, _ = self.model_outputs(layer_out)
                cache["logits"] = tf.concat([cache["logits"], next_logits],
                                            axis=1)
                return next_logits, cache

        self.finished_seq, self.finished_scores, states = beam_search.beam_search(
            symbols_to_logits_fn,
            initial_ids=start_ids,
            beam_size=self.beam_size,
            decode_length=self.max_dec_time_step,
            vocab_size=self.vocabulary_size,
            alpha=0.6,
            eos_id=eos_id,
            states=cache)
        beam_ids = self.finished_seq[:, 0, 1:]
        beam_ids = tf.pad(
            beam_ids,
            [[0, 0], [0, self.max_dec_time_step - tf.shape(beam_ids)[1]]],
            constant_values=pad_id)
        logits = states["logits"][:, 0, :, :]
        logits = tf.pad(
            logits, [[0, 0], [0, self.max_dec_time_step - tf.shape(logits)[1]],
                     [0, 0]],
            constant_values=self.parameters.invalid_logit)
        return logits, beam_ids
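Both the beam ids and the logits recovered from the beam-search state are padded to a fixed max_dec_time_step above; tf.pad's constant_values argument supplies the pad id (or an invalid-logit sentinel) in place of the default zeros. A short sketch (TF1.x assumed):

import tensorflow as tf

max_dec_time_step, pad_id = 6, 0
beam_ids = tf.constant([[4, 2, 1]])  # [batch=1, length=3]
beam_ids = tf.pad(
    beam_ids,
    [[0, 0], [0, max_dec_time_step - tf.shape(beam_ids)[1]]],
    constant_values=pad_id)          # -> [[4, 2, 1, 0, 0, 0]]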
Example #38
def fast_decode(symbols_to_logits_fn,
                hparams,
                decode_length,
                vocab_size,
                beam_size=1,
                top_beams=1,
                alpha=1.0,
                eos_id=beam_search.EOS_ID,
                batch_size=None,
                force_decode_length=False,
                cache=None):
    """Given encoder output and a symbols to logits function, does fast decoding.

    Implements both greedy and beam search decoding, uses beam search iff
    beam_size > 1, otherwise beam search related arguments are ignored.

    Args:
        symbols_to_logits_fn: Incremental decoding; function mapping triple
          `(ids, step, cache)` to symbol logits.
        hparams: run hyperparameters
        decode_length: an integer.  How many additional timesteps to decode.
        vocab_size: Output vocabulary size.
        beam_size: number of beams.
        top_beams: an integer. How many of the beams to return.
        alpha: Float that controls the length penalty. larger the alpha, stronger
          the preference for longer translations.
        eos_id: End-of-sequence symbol in beam search.
        batch_size: an integer scalar - must be passed if there is no input
        force_decode_length: bool, whether to force the full decode length, or if
          False, stop when all beams hit eos_id.
        cache: dict carrying the decoder's incremental state; the decode
          results are written into it before it is returned.

    Returns:
        A dict of decoding results {
          "outputs": integer `Tensor` of decoded ids of shape
              [batch_size, <= decode_length] if top_beams == 1 or
              [batch_size, top_beams, <= decode_length] otherwise
          "scores": decoding log probs from the beam search,
              None if using greedy decoding (beam_size=1)
        }

    Raises:
      NotImplementedError: If beam size > 1 with partial targets.
    """

    if beam_size > 1:  # Beam Search
        initial_ids = tf.zeros([batch_size], dtype=tf.int32)
        decoded_ids, scores, cache = beam_search.beam_search(
            symbols_to_logits_fn,
            initial_ids,
            beam_size,
            decode_length,
            vocab_size,
            alpha,
            states=cache,
            eos_id=eos_id,
            stop_early=(top_beams == 1))

        if top_beams == 1:
            decoded_ids = decoded_ids[:, 0, 1:]
            scores = scores[:, 0]
        else:
            decoded_ids = decoded_ids[:, :top_beams, 1:]
            scores = scores[:, :top_beams]

    else:  # Greedy

        def inner_loop(i, hit_eos, next_id, decoded_ids, cache, log_prob):
            """One step of greedy decoding."""
            logits, cache = symbols_to_logits_fn(next_id, i, cache)
            log_probs = common_layers.log_prob_from_logits(logits)
            temperature = (0.0 if hparams.sampling_method == "argmax" else
                           hparams.sampling_temp)
            next_id = common_layers.sample_with_temperature(
                logits, temperature)
            hit_eos |= tf.equal(next_id, eos_id)

            log_prob_indices = tf.stack(
                [tf.range(tf.to_int64(batch_size)), next_id], axis=1)
            log_prob += tf.gather_nd(log_probs, log_prob_indices)

            next_id = tf.expand_dims(next_id, axis=1)
            decoded_ids = tf.concat([decoded_ids, next_id], axis=1)
            return i + 1, hit_eos, next_id, decoded_ids, cache, log_prob

        def is_not_finished(i, hit_eos, *_):
            finished = i >= decode_length
            if not force_decode_length:
                finished |= tf.reduce_all(hit_eos)
            return tf.logical_not(finished)

        decoded_ids = tf.zeros([batch_size, 0], dtype=tf.int64)
        hit_eos = tf.fill([batch_size], False)
        next_id = tf.zeros([batch_size, 1], dtype=tf.int64)
        initial_log_prob = tf.zeros([batch_size], dtype=tf.float32)
        _, _, _, decoded_ids, cache, log_prob = tf.while_loop(
            is_not_finished,
            inner_loop, [
                tf.constant(0), hit_eos, next_id, decoded_ids, cache,
                initial_log_prob
            ],
            shape_invariants=[
                tf.TensorShape([]),
                tf.TensorShape([None]),
                tf.TensorShape([None, None]),
                tf.TensorShape([None, None]),
                tf.contrib.framework.nest.map_structure(
                    beam_search.get_state_shape_invariants, cache),
                tf.TensorShape([None]),
            ])
        scores = log_prob

    cache["outputs"] = decoded_ids
    cache["scores"] = scores

    return cache
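The greedy branch above keeps a running log probability by looking up, for every batch row, the log prob of the id it just sampled. A sketch of that tf.gather_nd lookup for a single step (TF1.x assumed; accumulated over steps it plays the role of the beam branch's scores):

import tensorflow as tf

batch_size = 2
logits = tf.constant([[2.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
log_probs = tf.nn.log_softmax(logits)
next_id = tf.argmax(logits, axis=-1)                      # int64, [batch]
indices = tf.stack([tf.range(tf.to_int64(batch_size)), next_id], axis=1)
step_log_prob = tf.gather_nd(log_probs, indices)          # [batch]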
Example #39
def fast_decode(
    encoder_output,
    encoder_decoder_attention_bias,
    symbols_to_logits_fn,
    hparams,
    decode_length,
    vocab_size,
    init_cache_fn=_init_transformer_cache,
    beam_size=1,
    top_beams=1,
    alpha=1.0,
    sos_id=0,
    eos_id=beam_search.EOS_ID,
    batch_size=None,
    force_decode_length=False,
    scope_prefix='body/',
    sampling_temperature=0.0,
    top_k=-1,
    cache=None,
):
    """Given encoder output and a symbols to logits function, does fast decoding.

  Implements both greedy and beam search decoding, uses beam search iff
  beam_size > 1, otherwise beam search related arguments are ignored.

  Args:
    encoder_output: Output from encoder.
    encoder_decoder_attention_bias: a bias tensor for use in encoder-decoder
      attention
    symbols_to_logits_fn: Incremental decoding; function mapping triple `(ids,
      step, cache)` to symbol logits.
    hparams: run hyperparameters
    decode_length: an integer.  How many additional timesteps to decode.
    vocab_size: Output vocabulary size.
    init_cache_fn: Function that returns the initial cache dict.
    beam_size: number of beams.
    top_beams: an integer. How many of the beams to return.
    alpha: Float that controls the length penalty. larger the alpha, stronger
      the preference for longer translations.
    sos_id: Start-of-sequence symbol; used as the initial decoder id.
    eos_id: End-of-sequence symbol in beam search.
    batch_size: an integer scalar - must be passed if there is no input
    force_decode_length: bool, whether to force the full decode length, or if
      False, stop when all beams hit eos_id.
    scope_prefix: str, prefix for decoder layer variable scopes.
    sampling_temperature: scalar, temperature with which to sample.
    top_k: scalar, sample only top k.
    cache: cache dictionary for additional predictions.

  Returns:
      A dict of decoding results {
          "outputs": integer `Tensor` of decoded ids of shape
              [batch_size, <= decode_length] if top_beams == 1 or
              [batch_size, top_beams, <= decode_length] otherwise
          "scores": decoding log probs from the beam search,
              None if using greedy decoding (beam_size=1)
      }
  """
    if encoder_output is not None:
        batch_size = common_layers.shape_list(encoder_output)[0]

    cache = init_cache_fn(
        cache=cache,
        hparams=hparams,
        batch_size=batch_size,
        attention_init_length=0,
        encoder_output=encoder_output,
        encoder_decoder_attention_bias=encoder_decoder_attention_bias,
        scope_prefix=scope_prefix,
    )

    if beam_size > 1:  # Beam Search
        initial_ids = sos_id * tf.ones([batch_size], dtype=tf.int32)
        decoded_ids, scores, cache = beam_search.beam_search(
            symbols_to_logits_fn,
            initial_ids,
            beam_size,
            decode_length,
            vocab_size,
            alpha,
            states=cache,
            eos_id=eos_id,
            stop_early=(top_beams == 1),
        )

        if top_beams == 1:
            decoded_ids = decoded_ids[:, 0, 1:]
            scores = scores[:, 0]
        else:
            decoded_ids = decoded_ids[:, :top_beams, 1:]
            scores = scores[:, :top_beams]
    else:

        def inner_loop(
            i,
            hit_eos,
            next_id,
            next_id_tag,
            decoded_ids,
            decoded_ids_tag,
            cache,
            log_prob,
        ):
            """One step of greedy decoding."""
            logits, logits_tag, cache = symbols_to_logits_fn(
                next_id, next_id_tag, i, cache)
            log_probs = common_layers.log_prob_from_logits(logits)
            temperature = sampling_temperature
            if hparams.sampling_method == 'random_per_example':
                next_id = common_layers.sample_temperature_per_example(
                    logits, temperature, top_k)
            else:
                if hparams.sampling_method == 'argmax':
                    temperature = 0.0
                next_id = common_layers.sample_with_temperature(
                    logits, temperature, top_k)

            if hparams.sampling_method == 'random_per_example':
                next_id_tag = common_layers.sample_temperature_per_example(
                    logits_tag, temperature, top_k)
            else:
                if hparams.sampling_method == 'argmax':
                    temperature = 0.0
                next_id_tag = common_layers.sample_with_temperature(
                    logits_tag, temperature, top_k)

            log_prob_indices = tf.stack(
                [tf.range(tf.to_int64(batch_size)), next_id], axis=1)
            log_prob += tf.gather_nd(
                log_probs, log_prob_indices) * (1 - tf.to_float(hit_eos))
            hit_eos |= tf.equal(next_id, eos_id)

            next_id = tf.expand_dims(next_id, axis=1)
            decoded_ids = tf.concat([decoded_ids, next_id], axis=1)
            next_id_tag = tf.expand_dims(next_id_tag, axis=1)
            decoded_ids_tag = tf.concat([decoded_ids_tag, next_id_tag], axis=1)

            return (
                i + 1,
                hit_eos,
                next_id,
                next_id_tag,
                decoded_ids,
                decoded_ids_tag,
                cache,
                log_prob,
            )

        def is_not_finished(i, hit_eos, *_):
            finished = i >= decode_length
            if not force_decode_length:
                finished |= tf.reduce_all(hit_eos)
            return tf.logical_not(finished)

        decoded_ids = tf.zeros([batch_size, 0], dtype=tf.int64)
        decoded_ids_tag = tf.zeros([batch_size, 0], dtype=tf.int64)
        hit_eos = tf.fill([batch_size], False)
        next_id = sos_id * tf.ones([batch_size, 1], dtype=tf.int64)
        next_id_tag = sos_id * tf.ones([batch_size, 1], dtype=tf.int64)
        initial_log_prob = tf.zeros([batch_size], dtype=tf.float32)

        _, _, _, _, decoded_ids, decoded_ids_tag, cache, log_prob = tf.while_loop(
            is_not_finished,
            inner_loop,
            [
                tf.constant(0),
                hit_eos,
                next_id,
                next_id_tag,
                decoded_ids,
                decoded_ids_tag,
                cache,
                initial_log_prob,
            ],
            shape_invariants=[
                tf.TensorShape([]),
                tf.TensorShape([None]),
                tf.TensorShape([None, None]),
                tf.TensorShape([None, None]),
                tf.TensorShape([None, None]),
                tf.TensorShape([None, None]),
                nest.map_structure(beam_search.get_state_shape_invariants,
                                   cache),
                tf.TensorShape([None]),
            ],
        )
        scores = log_prob

    return {
        'outputs': decoded_ids,
        'outputs_tag': decoded_ids_tag,
        'scores': scores,
        'cache': cache,
    }
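The sampling helpers referenced above are not shown in these snippets; a minimal hedged stand-in for the argmax-vs-temperature behaviour of common_layers.sample_with_temperature (hypothetical reimplementation, TF1.x assumed, top_k ignored) might look like:

import tensorflow as tf

def sample_with_temperature(logits, temperature):
    """Argmax when temperature is 0.0, otherwise softmax sampling."""
    if temperature == 0.0:
        return tf.argmax(logits, axis=-1)
    # Dividing logits by the temperature flattens (>1) or sharpens (<1) the
    # distribution before sampling one id per batch row.
    return tf.squeeze(
        tf.multinomial(logits / temperature, num_samples=1), axis=1)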
Example #40
    def _fast_decode(self,
                     features,
                     decode_length,
                     last_position_only=True,
                     beam_size=1,
                     top_beams=1,
                     alpha=1.0):
        """Fast decoding.

    Implements both greedy and beam search decoding, uses beam search iff
    beam_size > 1, otherwise beam search related arguments are ignored.

    Args:
      features: a map of string to model  features.
      decode_length: an integer.  How many additional timesteps to decode.
      last_position_only: MUST be true for fast decoding!
      beam_size: number of beams.
      top_beams: an integer. How many of the beams to return.
      alpha: Float that controls the length penalty. larger the alpha, stronger
        the preference for longer translations.

    Returns:
       samples: an integer `Tensor`. Top samples from the beam search

    Raises:
      ValueError: If last_position_only is False
      NotImplementedError: If there are multiple data shards.
    """
        if not last_position_only:
            raise ValueError(
                "Fast decoding only deals with the last positions!")
        if self._num_datashards != 1:
            raise NotImplementedError(
                "Fast decoding only supports a single shard.")
        dp = self._data_parallelism
        hparams = self._hparams

        inputs = features["inputs"]
        batch_size = tf.shape(inputs)[0]
        target_modality = self._problem_hparams.target_modality
        if t2t_model.is_class_modality(target_modality):
            decode_length = 1
        else:
            decode_length = tf.shape(inputs)[1] + decode_length

        # TODO(llion): Clean up this reshaping logic.
        inputs = tf.expand_dims(inputs, axis=1)
        if len(inputs.shape) < 5:
            inputs = tf.expand_dims(inputs, axis=4)
        s = tf.shape(inputs)
        inputs = tf.reshape(inputs, [s[0] * s[1], s[2], s[3], s[4]])
        # _shard_features called to ensure that the variable names match
        inputs = self._shard_features({"inputs": inputs})["inputs"]
        input_modality = self._problem_hparams.input_modality["inputs"]
        with tf.variable_scope(input_modality.name):
            inputs = input_modality.bottom_sharded(inputs, dp)
        with tf.variable_scope("body"):
            encoder_output, encoder_decoder_attention_bias = dp(
                self.encode, inputs, features["target_space_id"], hparams)
        encoder_output = encoder_output[0]
        encoder_decoder_attention_bias = encoder_decoder_attention_bias[0]

        if hparams.pos == "timing":
            timing_signal = common_attention.get_timing_signal_1d(
                decode_length + 1, hparams.hidden_size)

        def preprocess_targets(targets, i):
            """Performs preprocessing steps on the targets to prepare for the decoder.

      This includes:
        - Embedding the ids.
        - Flattening to 3D tensor.
        - Optionally adding timing signals.

      Args:
        targets: inputs ids to the decoder. [batch_size, 1]
        i: scalar, Step number of the decoding loop.

      Returns:
        Processed targets [batch_size, 1, hidden_dim]
      """
            # _shard_features called to ensure that the variable names match
            targets = self._shard_features({"targets": targets})["targets"]
            with tf.variable_scope(target_modality.name):
                targets = target_modality.targets_bottom_sharded(targets,
                                                                 dp)[0]
            targets = common_layers.flatten4d3d(targets)

            # TODO(llion): Explain! Is this even needed?
            targets = tf.cond(tf.equal(i, 0), lambda: tf.zeros_like(targets),
                              lambda: targets)

            if hparams.pos == "timing":
                targets += timing_signal[:, i:i + 1]
            return targets

        decoder_self_attention_bias = (
            common_attention.attention_bias_lower_triangle(decode_length))
        if hparams.proximity_bias:
            decoder_self_attention_bias += common_attention.attention_bias_proximal(
                decode_length)

        def symbols_to_logits_fn(ids, i, cache):
            """Go from ids to logits for next symbol."""
            ids = ids[:, -1:]
            targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
            targets = preprocess_targets(targets, i)

            bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]

            with tf.variable_scope("body"):
                body_outputs = dp(self.decode, targets,
                                  cache["encoder_output"],
                                  cache["encoder_decoder_attention_bias"],
                                  bias, hparams, cache)

            with tf.variable_scope(target_modality.name):
                logits = target_modality.top_sharded(body_outputs, None, dp)[0]

            return tf.squeeze(logits, axis=[1, 2, 3]), cache

        key_channels = hparams.attention_key_channels or hparams.hidden_size
        value_channels = hparams.attention_value_channels or hparams.hidden_size
        num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers

        cache = {
            "layer_%d" % layer: {
                "k": tf.zeros([batch_size, 0, key_channels]),
                "v": tf.zeros([batch_size, 0, value_channels]),
            }
            for layer in range(num_layers)
        }

        # Set 2nd dim to None since it's not invariant in the tf.while_loop
        # Note: Tensor.set_shape() does not work here since it merges shape info.
        # TODO(llion); Find a more robust solution.
        # pylint: disable=protected-access
        for layer in cache:
            cache[layer]["k"]._shape = tf.TensorShape(
                [None, None, key_channels])
            cache[layer]["v"]._shape = tf.TensorShape(
                [None, None, value_channels])
        # pylint: enable=protected-access
        cache["encoder_output"] = encoder_output
        cache[
            "encoder_decoder_attention_bias"] = encoder_decoder_attention_bias

        if beam_size > 1:  # Beam Search
            target_modality = (
                self._hparams.problems[self._problem_idx].target_modality)
            vocab_size = target_modality.top_dimensionality
            initial_ids = tf.zeros([batch_size], dtype=tf.int32)
            decoded_ids, _ = beam_search.beam_search(symbols_to_logits_fn,
                                                     initial_ids,
                                                     beam_size,
                                                     decode_length,
                                                     vocab_size,
                                                     alpha,
                                                     states=cache)

            if top_beams == 1:
                decoded_ids = decoded_ids[:, 0, 1:]
            else:
                decoded_ids = decoded_ids[:, :top_beams, 1:]
        else:  # Greedy

            def inner_loop(i, next_id, decoded_ids, cache):
                logits, cache = symbols_to_logits_fn(next_id, i, cache)
                next_id = tf.expand_dims(tf.argmax(logits, axis=-1), axis=1)
                decoded_ids = tf.concat([decoded_ids, next_id], axis=1)
                return i + 1, next_id, decoded_ids, cache

            decoded_ids = tf.zeros([batch_size, 0], dtype=tf.int64)
            next_id = tf.zeros([batch_size, 1], dtype=tf.int64)
            _, _, decoded_ids, _ = tf.while_loop(
                # TODO(llion): Early stopping.
                lambda i, *_: tf.less(i, decode_length),
                inner_loop,
                [tf.constant(0), next_id, decoded_ids, cache],
                shape_invariants=[
                    tf.TensorShape([]),
                    tf.TensorShape([None, None]),
                    tf.TensorShape([None, None]),
                    nest.map_structure(lambda t: tf.TensorShape(t.shape),
                                       cache),
                ])

        return decoded_ids
Example #41
def fast_decode(wav_encoder_output,
                txt_encoder_output,
                wav_enc_dec_attention_bias,
                txt_enc_dec_attention_bias,
                symbols_to_logits_fn,
                hparams,
                decode_length,
                vocab_size,
                beam_size=1,
                top_beams=1,
                alpha=1.0,
                eos_id=beam_search.EOS_ID,
                batch_size=None,
                force_decode_length=False):
  """ implement greedy and beam search
  Args:
    wav_encoder_output: Output from wav encoder.
    txt_encoder_output: Output from txt encoder.
    wav_enc_dec_attention_bias: a bias tensor for use in enc-dec attention
      over wav inputs
    txt_enc_dec_attention_bias: a bias tensor for use in enc-dec attention
      over txt inputs
    symbols_to_logits_fn: Incremental decoding; function mapping triple
      `(ids, step, cache)` to symbol logits.
    hparams: run hyperparameters
    decode_length: an integer.  How many additional timesteps to decode.
    vocab_size: Output vocabulary size.
    beam_size: number of beams.
    top_beams: an integer. How many of the beams to return.
    alpha: Float that controls the length penalty. larger the alpha, stronger
      the preference for longer translations.
    eos_id: End-of-sequence symbol in beam search.
    batch_size: an integer scalar - must be passed if there is no input
    force_decode_length: bool, whether to force the full decode length, or if
      False, stop when all beams hit eos_id.

  Returns:
      A dict of decoding results {
          "outputs": integer `Tensor` of decoded ids of shape
              [batch_size, <= decode_length] if top_beams == 1 or
              [batch_size, top_beams, <= decode_length] otherwise
          "scores": decoding log probs from the beam search,
              None if using greedy decoding (beam_size=1)
      }

    Raises:
      NotImplementedError: If beam size > 1 with partial targets.
  """
  if wav_encoder_output is not None:
    batch_size = common_layers.shape_list(wav_encoder_output)[0]

  key_channels = hparams.attention_key_channels or hparams.hidden_size
  value_channels = hparams.attention_value_channels or hparams.hidden_size
  num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers

  cache = {
      "layer_%d" % layer: {
          "k": tf.zeros([batch_size, 0, key_channels]),
          "v": tf.zeros([batch_size, 0, value_channels]),
          "f": tf.zeros([batch_size, 0, hparams.hidden_size]),
      } for layer in range(num_layers)
  }

  # Tensors cannot be used in a Python boolean context; compare against None.
  if txt_encoder_output is not None and wav_encoder_output is not None:
    cache["wav_enc_output"] = wav_encoder_output
    cache["txt_enc_output"] = txt_encoder_output
    cache["wav_enc_dec_attention_bias"] = wav_enc_dec_attention_bias
    cache["txt_enc_dec_attention_bias"] = txt_enc_dec_attention_bias

  if beam_size > 1:  # Beam Search
    initial_ids = tf.zeros([batch_size], dtype=tf.int32)
    decoded_ids, scores = beam_search.beam_search(
        symbols_to_logits_fn,
        initial_ids,
        beam_size,
        decode_length,
        vocab_size,
        alpha,
        states=cache,
        eos_id=eos_id,
        stop_early=(top_beams == 1))

    if top_beams == 1:
      decoded_ids = decoded_ids[:, 0, 1:]
      scores = scores[:, 0]
    else:
      decoded_ids = decoded_ids[:, :top_beams, 1:]
      scores = scores[:, :top_beams]
  else:  # Greedy search
    def inner_loop(i, hit_eos, next_id, decoded_ids, cache, log_prob):
      """One step of greedy decoding."""
      logits, cache = symbols_to_logits_fn(next_id, i, cache)
      log_probs = common_layers.log_prob_from_logits(logits)
      temperature = (0.0 if hparams.sampling_method == "argmax" else
                     hparams.sampling_temp)
      next_id = common_layers.sample_with_temperature(logits, temperature)
      hit_eos |= tf.equal(next_id, eos_id)

      log_prob_indices = tf.stack(
          [tf.range(tf.to_int64(batch_size)), next_id], axis=1)
      log_prob += tf.gather_nd(log_probs, log_prob_indices)

      next_id = tf.expand_dims(next_id, axis=1)
      decoded_ids = tf.concat([decoded_ids, next_id], axis=1)
      return i + 1, hit_eos, next_id, decoded_ids, cache, log_prob

    def is_not_finished(i, hit_eos, *_):
      finished = i >= decode_length
      if not force_decode_length:
        finished |= tf.reduce_all(hit_eos)
      return tf.logical_not(finished)

    decoded_ids = tf.zeros([batch_size, 0], dtype=tf.int64)
    hit_eos = tf.fill([batch_size], False)
    next_id = tf.zeros([batch_size, 1], dtype=tf.int64)
    initial_log_prob = tf.zeros([batch_size], dtype=tf.float32)
    _, _, _, decoded_ids, _, log_prob = tf.while_loop(
        is_not_finished,
        inner_loop, [
            tf.constant(0), hit_eos, next_id, decoded_ids, cache,
            initial_log_prob
        ],
        shape_invariants=[
            tf.TensorShape([]),
            tf.TensorShape([None]),
            tf.TensorShape([None, None]),
            tf.TensorShape([None, None]),
            nest.map_structure(beam_search.get_state_shape_invariants, cache),
            tf.TensorShape([None]),
        ])
    scores = log_prob

  return {"outputs": decoded_ids, "scores": scores}
Example #42
def fast_decode(encoder_output,
                encoder_decoder_attention_bias,
                symbols_to_logits_fn,
                hparams,
                decode_length,
                vocab_size,
                beam_size=1,
                top_beams=1,
                alpha=1.0,
                eos_id=beam_search.EOS_ID):
    """Given encoder output and a symbols to logits function, does fast decoding.

  Implements both greedy and beam search decoding, uses beam search iff
  beam_size > 1, otherwise beam search related arguments are ignored.

  Args:
    encoder_output: Output from encoder.
    encoder_decoder_attention_bias: a bias tensor for use in encoder-decoder
      attention
    symbols_to_logits_fn: Incremental decoding; function mapping triple
      `(ids, step, cache)` to symbol logits.
    hparams: run hyperparameters
    decode_length: an integer.  How many additional timesteps to decode.
    vocab_size: Output vocabulary size.
    beam_size: number of beams.
    top_beams: an integer. How many of the beams to return.
    alpha: Float that controls the length penalty. larger the alpha, stronger
      the preference for longer translations.
    eos_id: End-of-sequence symbol in beam search.

  Returns:
    Pair of tensors `(decoded_ids, scores)`, where `decoded_ids` is a 2-d or 3-d
    (when doing beam search with top_beams > 1) tensor containing result of
    decoding, and `scores` is the beam search scores.
  """
    batch_size = common_layers.shape_list(encoder_output)[0]

    key_channels = hparams.attention_key_channels or hparams.hidden_size
    value_channels = hparams.attention_value_channels or hparams.hidden_size
    num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers

    cache = {
        "layer_%d" % layer: {
            "k": tf.zeros([batch_size, 0, key_channels]),
            "v": tf.zeros([batch_size, 0, value_channels]),
        }
        for layer in range(num_layers)
    }

    # Set 2nd dim to None since it's not invariant in the tf.while_loop
    # Note: Tensor.set_shape() does not work here since it merges shape info.
    # TODO(llion); Find a more robust solution.
    # pylint: disable=protected-access
    if not context.in_eager_mode():
        for layer in cache:
            cache[layer]["k"]._shape = tf.TensorShape(
                [None, None, key_channels])
            cache[layer]["v"]._shape = tf.TensorShape(
                [None, None, value_channels])
    # pylint: enable=protected-access
    cache["encoder_output"] = encoder_output
    cache["encoder_decoder_attention_bias"] = encoder_decoder_attention_bias

    if beam_size > 1:  # Beam Search
        initial_ids = tf.zeros([batch_size], dtype=tf.int32)
        decoded_ids, scores = beam_search.beam_search(
            symbols_to_logits_fn,
            initial_ids,
            beam_size,
            decode_length,
            vocab_size,
            alpha,
            states=cache,
            eos_id=eos_id,
            stop_early=(top_beams == 1))

        if top_beams == 1:
            decoded_ids = decoded_ids[:, 0, 1:]
        else:
            decoded_ids = decoded_ids[:, :top_beams, 1:]
    else:  # Greedy

        def inner_loop(i, next_id, decoded_ids, cache):
            logits, cache = symbols_to_logits_fn(next_id, i, cache)
            temperature = (0.0 if hparams.sampling_method == "argmax" else
                           hparams.sampling_temp)
            next_id = tf.expand_dims(common_layers.sample_with_temperature(
                logits, temperature),
                                     axis=1)
            decoded_ids = tf.concat([decoded_ids, next_id], axis=1)
            return i + 1, next_id, decoded_ids, cache

        decoded_ids = tf.zeros([batch_size, 0], dtype=tf.int64)
        scores = None
        next_id = tf.zeros([batch_size, 1], dtype=tf.int64)
        _, _, decoded_ids, _ = tf.while_loop(
            # TODO(llion): Early stopping.
            lambda i, *_: tf.less(i, decode_length),
            inner_loop,
            [tf.constant(0), next_id, decoded_ids, cache],
            shape_invariants=[
                tf.TensorShape([]),
                tf.TensorShape([None, None]),
                tf.TensorShape([None, None]),
                nest.map_structure(lambda t: tf.TensorShape(t.shape), cache),
            ])

    return decoded_ids, scores
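
# --- Usage sketch (not part of the original example) ---
# Illustrates the contract the incremental decoder above expects from
# symbols_to_logits_fn: a map from (ids, step, cache) to (logits, cache).
# The toy vocabulary size and the commented-out call are assumptions for
# this sketch, not taken from the original code.
def toy_symbols_to_logits_fn(ids, i, cache):
    """Toy incremental model: puts almost all mass on token id 2."""
    batch = tf.shape(ids)[0]
    vocab_size = 5  # assumed toy vocabulary
    logits = tf.one_hot(tf.fill([batch], 2), vocab_size) * 10.0
    return logits, cache

# Hypothetical call, assuming the enclosing helper is named `fast_decoding`
# and its signature follows the docstring above:
# decoded_ids, scores = fast_decoding(
#     encoder_output, encoder_decoder_attention_bias,
#     toy_symbols_to_logits_fn, hparams, decode_length=10,
#     vocab_size=5, beam_size=4, top_beams=1, alpha=0.6, eos_id=1)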
Example #43
0
    def _fast_decode(self,
                     features,
                     decode_length,
                     beam_size=1,
                     top_beams=1,
                     alpha=1.0):
        """
		Fast decoding.
		Implements both greedy and beam search decoding, uses beam search iff
		beam_size > 1, otherwise beam search related arguments are ignored.
		Args:
			features: a map of string to model  features.
			decode_length: an integer.  How many additional timesteps to decode.
			beam_size: number of beams.
			top_beams: an integer. How many of the beams to return.
			alpha: Float that controls the length penalty. larger the alpha, stronger
			the preference for slonger translations.
		Returns:
			 samples: an integer `Tensor`. Top samples from the beam search
		Raises:
			NotImplementedError: If there are multiple data shards.
		"""
        if self._num_datashards != 1:
            raise NotImplementedError(
                "Fast decoding only supports a single shard.")
        dp = self._data_parallelism
        hparams = self._hparams

        inputs = features["inputs"]
        batch_size = common_layers.shape_list(inputs)[0]
        target_modality = self._problem_hparams.target_modality
        if target_modality.is_class_modality:
            decode_length = 1
        else:
            decode_length = common_layers.shape_list(inputs)[1] + decode_length

        # TODO(llion): Clean up this reshaping logic.
        inputs = tf.expand_dims(inputs, axis=1)
        if len(inputs.shape) < 5:
            inputs = tf.expand_dims(inputs, axis=4)
        s = common_layers.shape_list(inputs)
        inputs = tf.reshape(inputs, [s[0] * s[1], s[2], s[3], s[4]])
        # _shard_features called to ensure that the variable names match
        inputs = self._shard_features({"inputs": inputs})["inputs"]
        input_modality = self._problem_hparams.input_modality["inputs"]
        with tf.variable_scope(input_modality.name):
            inputs = input_modality.bottom_sharded(inputs, dp)
        with tf.variable_scope("body"):
            encoder_output, encoder_decoder_attention_bias = dp(
                self.encode,
                inputs,
                features["target_space_id"],
                hparams,
                features=features)
        encoder_output = encoder_output[0]
        encoder_decoder_attention_bias = encoder_decoder_attention_bias[0]

        if hparams.pos == "timing":
            timing_signal = common_attention.get_timing_signal_1d(
                decode_length + 1, hparams.hidden_size)

        def preprocess_targets(targets, i):
            """Performs preprocessing steps on the targets to prepare for the decoder.
			This includes:
			- Embedding the ids.
			- Flattening to 3D tensor.
			- Optionally adding timing signals.
			Args:
			targets: inputs ids to the decoder. [batch_size, 1]
			i: scalar, Step number of the decoding loop.
			Returns:
			Processed targets [batch_size, 1, hidden_dim]
			"""
            # _shard_features called to ensure that the variable names match
            targets = self._shard_features({"targets": targets})["targets"]
            with tf.variable_scope(target_modality.name):
                targets = target_modality.targets_bottom_sharded(targets,
                                                                 dp)[0]
            targets = common_layers.flatten4d3d(targets)

            # TODO(llion): Explain! Is this even needed?
            targets = tf.cond(tf.equal(i, 0), lambda: tf.zeros_like(targets),
                              lambda: targets)
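            # (Likely rationale, stated here as an assumption: training shifts
            # targets right with zeros, so at step 0 the decoder input should
            # be the zero vector, not the embedding of the initial id.)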

            if hparams.pos == "timing":
                targets += timing_signal[:, i:i + 1]
            return targets

        decoder_self_attention_bias = (
            common_attention.attention_bias_lower_triangle(decode_length))
        if hparams.proximity_bias:
            decoder_self_attention_bias += common_attention.attention_bias_proximal(
                decode_length)

        def symbols_to_logits_fn(ids, i, cache):
            """Go from ids to logits for next symbol."""
            ids = ids[:, -1:]
            targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
            targets = preprocess_targets(targets, i)

            bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]

            with tf.variable_scope("body"):
                body_outputs = dp(
                    self.decode,
                    targets,
                    cache["encoder_output"],
                    cache["encoder_decoder_attention_bias"],
                    bias,
                    hparams,
                    cache,
                    nonpadding=transformer._features_to_nonpadding(
                        features, "targets"))

            with tf.variable_scope(target_modality.name):
                logits = target_modality.top_sharded(body_outputs, None, dp)[0]

            return tf.squeeze(logits, axis=[1, 2, 3]), cache

        key_channels = hparams.attention_key_channels or hparams.hidden_size
        value_channels = hparams.attention_value_channels or hparams.hidden_size
        num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers

        cache = {
            "layer_%d" % layer: {
                "k": tf.zeros([batch_size, 0, key_channels]),
                "v": tf.zeros([batch_size, 0, value_channels]),
            }
            for layer in range(num_layers)
        }

        # Set 2nd dim to None since it's not invariant in the tf.while_loop
        # Note: Tensor.set_shape() does not work here since it merges shape info.
        # TODO(llion): Find a more robust solution.
        # pylint: disable=protected-access
        if not context.in_eager_mode():
            for layer in cache:
                cache[layer]["k"]._shape = tf.TensorShape(
                    [None, None, key_channels])
                cache[layer]["v"]._shape = tf.TensorShape(
                    [None, None, value_channels])
        # pylint: enable=protected-access
        cache["encoder_output"] = encoder_output
        cache["encoder_decoder_attention_bias"] = encoder_decoder_attention_bias

        if beam_size > 1:  # Beam Search
            target_modality = (
                self._hparams.problems[self._problem_idx].target_modality)
            vocab_size = target_modality.top_dimensionality
            initial_ids = tf.zeros([batch_size], dtype=tf.int32)
            decoded_ids, scores = beam_search.beam_search(
                symbols_to_logits_fn,
                initial_ids,
                beam_size,
                decode_length,
                vocab_size,
                alpha,
                states=cache,
                stop_early=(top_beams == 1))

            decoded_ids = decoded_ids[:, :, 1:]

            # Do roulette-wheel selection (or inverse roulette-wheel
            # selection) over the returned beams.
            if self._hparams.roulette in ("Normal", "Inverse"):
                if self._hparams.roulette == "Normal":
                    probabilities = tf.pow(tf.constant(2.0), scores)
                    start = 0
                else:
                    probabilities = tf.subtract(
                        tf.constant(1.0), tf.pow(tf.constant(2.0), scores))
                    start = beam_size - self._hparams.roulette_beam_size

                summ = tf.reduce_sum(probabilities)
                ex_probs = tf.divide(probabilities, summ)
                # Alternative: ex_probs = tf.nn.softmax(probabilities)

                # sample a number between 0 and 1
                wheel = tf.random_uniform([1])
                upper_bound = tf.constant(0.0)

                # change this as well if using inverse
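                # `wheel` falls into exactly one beam's cumulative-probability
                # interval [upper_bound - p_i, upper_bound], so exactly one
                # tf.cond below selects that beam's ids and score.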
                for i in range(start, self._hparams.roulette_beam_size):
                    upper_bound = tf.add(ex_probs[:, i], upper_bound)
                    truth_value = tf.squeeze(
                        tf.logical_and(wheel >= upper_bound - ex_probs[:, i],
                                       wheel <= upper_bound))
                    decoded_ids, scores, i = tf.cond(
                        truth_value, lambda:
                        (decoded_ids[:, i, :], scores[:, i], beam_size),
                        lambda: (decoded_ids, scores, i))

        else:  # Greedy

            def inner_loop(i, next_id, decoded_ids, cache):
                logits, cache = symbols_to_logits_fn(next_id, i, cache)
                temperature = (0.0 if hparams.sampling_method == "argmax" else
                               hparams.sampling_temp)
                next_id = tf.expand_dims(common_layers.sample_with_temperature(
                    logits, temperature),
                                         axis=1)
                decoded_ids = tf.concat([decoded_ids, next_id], axis=1)
                return i + 1, next_id, decoded_ids, cache

            decoded_ids = tf.zeros([batch_size, 0], dtype=tf.int64)
            scores = None
            next_id = tf.zeros([batch_size, 1], dtype=tf.int64)
            _, _, decoded_ids, _ = tf.while_loop(
                # TODO(llion): Early stopping.
                lambda i, *_: tf.less(i, decode_length),
                inner_loop,
                [tf.constant(0), next_id, decoded_ids, cache],
                shape_invariants=[
                    tf.TensorShape([]),
                    tf.TensorShape([None, None]),
                    tf.TensorShape([None, None]),
                    nest.map_structure(lambda t: tf.TensorShape(t.shape),
                                       cache),
                ])

        return decoded_ids, scores
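
# --- Standalone sketch of roulette-wheel selection (illustrative only) ---
# A NumPy restatement of the selection loop above, assuming `scores` are
# base-2 log-probabilities, as the tf.pow(2.0, scores) call implies.
import numpy as np

def roulette_select(scores, rng=np.random):
    probs = np.exp2(scores)          # back to probabilities
    probs = probs / probs.sum()      # normalize the wheel
    wheel = rng.uniform()
    upper = 0.0
    for i, p in enumerate(probs):
        upper += p
        if upper - p <= wheel <= upper:  # wheel landed in beam i's slice
            return i
    return len(probs) - 1                # fallback for numerical edge cases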
Example #44
0
  def _beam_decode(self, features, decode_length, beam_size, top_beams,
                   last_position_only, alpha):
    """Beam search decoding.

    Args:
      features: a map of string to `Tensor`.
      decode_length: an integer.  How many additional timesteps to decode.
      beam_size: number of beams.
      top_beams: an integer. How many of the beams to return.
      last_position_only: a boolean, speed-up by computing last position only.
      alpha: Float that controls the length penalty. The larger the alpha, the
        stronger the preference for longer translations.

    Returns:
       samples: an integer `Tensor`. Top samples from the beam search
    """

    def symbols_to_logits_fn(ids):
      """Go from ids to logits."""
      ids = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
      ids = tf.pad(ids[:, 1:], [[0, 0], [0, 1], [0, 0], [0, 0]])
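      # Drop the initial GO id and right-pad so `ids` keeps its length and
      # lines up with the target positions that model_fn expects.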

      features["targets"] = ids
      self._coverage = None
      sharded_logits, _, _ = self.model_fn(
          features, False, last_position_only=last_position_only)
      # now self._coverage is a coverage tensor for the first datashard.
      # it has shape [batch_size] and contains floats between 0 and
      # source_length.
      logits = sharded_logits[0]  # Assuming we have one shard.
      if last_position_only:
        return tf.squeeze(logits, axis=[1, 2, 3])
      current_output_position = tf.shape(ids)[1] - 1  # -1 due to the pad above.
      logits = logits[:, current_output_position, :, :]
      return tf.squeeze(logits, axis=[1, 2])

    batch_size = tf.shape(features["inputs"])[0]
    initial_ids = tf.zeros([batch_size], dtype=tf.int32)

    inputs_old = features["inputs"]
    features["inputs"] = tf.expand_dims(features["inputs"], 1)
    if len(features["inputs"].shape) < 5:
      features["inputs"] = tf.expand_dims(features["inputs"], 4)
    # Expand the inputs to the beam size.
    features["inputs"] = tf.tile(features["inputs"], [1, beam_size, 1, 1, 1])
    s = tf.shape(features["inputs"])
    features["inputs"] = tf.reshape(features["inputs"],
                                    [s[0] * s[1], s[2], s[3], s[4]])

    target_modality = self._hparams.problems[self._problem_idx].target_modality
    vocab_size = target_modality.top_dimensionality
    # Setting decode length to input length + decode_length
    decode_length = tf.shape(features["inputs"])[1] + tf.constant(decode_length)
    ids, scores = beam_search.beam_search(symbols_to_logits_fn, initial_ids,
                                          beam_size, decode_length, vocab_size,
                                          alpha)

    # Set inputs back to the unexpanded inputs so as not to confuse the Estimator!
    features["inputs"] = inputs_old

    # Return `top_beams` decodings (also removing the initial id from the beam search output).
    return_scores = False  # TODO(lukaszkaiser): make it work multi-problem.
    if top_beams == 1:
      if return_scores:
        return {"outputs": ids[:, 0, 1:], "scores": scores}
      return ids[:, 0, 1:]
    else:
      if return_scores:
        return {"outputs": ids[:, :top_beams, 1:], "scores": scores}
      return ids[:, :top_beams, 1:]
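
# --- Shape sketch for the beam expansion above (illustrative only) ---
# Shows how expand_dims + tile + reshape give every beam its own copy of the
# source. The toy shapes are assumptions: batch=2, length=7, beam_size=4.
import tensorflow as tf

inputs = tf.zeros([2, 7, 1, 1])
beam_size = 4
x = tf.expand_dims(inputs, 1)                        # [2, 1, 7, 1, 1]
x = tf.tile(x, [1, beam_size, 1, 1, 1])              # [2, 4, 7, 1, 1]
s = tf.shape(x)
x = tf.reshape(x, [s[0] * s[1], s[2], s[3], s[4]])   # [8, 7, 1, 1]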