def testNotGreedyBeamTwo(self):
  batch_size = 1
  beam_size = 2
  vocab_size = 3
  decode_length = 3

  initial_ids = tf.constant([0] * batch_size)  # GO
  probabilities = tf.constant([[[0.1, 0.1, 0.8], [0.1, 0.1, 0.8]],
                               [[0.4, 0.5, 0.1], [0.2, 0.4, 0.4]],
                               [[0.05, 0.9, 0.05], [0.4, 0.4, 0.2]]])

  def symbols_to_logits(ids):
    pos = tf.shape(ids)[1]
    logits = tf.to_float(tf.log(probabilities[pos - 1, :]))
    return logits

  final_ids, final_probs = beam_search.beam_search(
      symbols_to_logits, initial_ids, beam_size, decode_length, vocab_size,
      0.0, eos_id=1)

  with self.test_session():
    ids = final_ids.eval()
    probs = final_probs.eval()
  self.assertAllEqual([[[0, 2, 1, 0], [0, 2, 0, 1]]], ids)
  self.assertAllClose([[0.8 * 0.5, 0.8 * 0.4 * 0.9]], np.exp(probs))
def testGreedyWithCornerCase(self):
  batch_size = 1
  beam_size = 1
  vocab_size = 3
  decode_length = 2

  initial_ids = tf.constant([0] * batch_size)  # GO
  probabilities = tf.constant([[0.2, 0.1, 0.7], [0.4, 0.1, 0.5]])

  def symbols_to_logits(ids):
    pos = tf.shape(ids)[1]
    logits = tf.to_float(tf.log(probabilities[pos - 1, :]))
    return logits

  final_ids, final_probs, _ = beam_search.beam_search(
      symbols_to_logits, initial_ids, beam_size, decode_length, vocab_size,
      0.0, eos_id=1)

  with self.test_session():
    ids = final_ids.eval()
    probs = final_probs.eval()
  self.assertAllEqual([[[0, 2, 2]]], ids)
  self.assertAllClose([[0.7 * 0.5]], np.exp(probs))
def ae_latent_sample_beam(latents_dense_in, inputs, ed, embed, hparams):
  """Sample from the latent space in the autoencoder."""

  def symbols_to_logits_fn(ids):
    """Go from ids to logits."""
    ids = tf.expand_dims(ids, axis=2)  # Ids start with added all-zeros.
    latents_discrete = tf.pad(ids[:, 1:], [[0, 0], [0, 1], [0, 0]])
    with tf.variable_scope(tf.get_variable_scope(), reuse=False):
      latents_dense = embed(
          tf.one_hot(latents_discrete, depth=2**hparams.bottleneck_bits))
      latents_pred = decode_transformer(inputs, ed, latents_dense, hparams,
                                        "extra")
      logits = tf.layers.dense(
          latents_pred, 2**hparams.bottleneck_bits, name="extra_logits")
      current_output_position = common_layers.shape_list(ids)[1] - 1
      logits = logits[:, current_output_position, :, :]
    return tf.squeeze(logits, axis=[1])

  initial_ids = tf.zeros([tf.shape(latents_dense_in)[0]], dtype=tf.int32)
  length = tf.shape(latents_dense_in)[1]
  ids, _ = beam_search.beam_search(
      symbols_to_logits_fn,
      initial_ids,
      beam_size=1,
      decode_length=length,
      vocab_size=2**hparams.bottleneck_bits,
      alpha=0.0,
      eos_id=-1,
      stop_early=False)

  res = tf.expand_dims(ids[:, 0, :], axis=2)  # Pick first beam.
  return res[:, 1:]  # Remove the added all-zeros from ids.
def testStatesAfterLoop(self):
  batch_size = 1
  beam_size = 1
  vocab_size = 2
  decode_length = 3

  initial_ids = tf.constant([0] * batch_size)  # GO
  probabilities = tf.constant([[[0.7, 0.3]], [[0.4, 0.6]], [[0.5, 0.5]]])

  def symbols_to_logits(ids, _, states):
    pos = tf.shape(ids)[1] - 1
    logits = tf.to_float(tf.log(probabilities[pos, :]))
    states["state"] += 1
    return logits, states

  states = {
      "state": tf.zeros((batch_size, 1)),
  }
  states["state"] = tf.placeholder_with_default(
      states["state"], shape=(None, 1))

  _, _, final_states = beam_search.beam_search(
      symbols_to_logits,
      initial_ids,
      beam_size,
      decode_length,
      vocab_size,
      0.0,
      eos_id=1,
      states=states)

  with self.test_session() as sess:
    final_states = sess.run(final_states)
  self.assertAllEqual([[[2]]], final_states["state"])
def testNotGreedyBeamTwoWithoutStopEarly(self):
  batch_size = 1
  beam_size = 2
  vocab_size = 3
  decode_length = 3

  initial_ids = tf.constant([0] * batch_size)  # GO
  probabilities = tf.constant([[[0.1, 0.1, 0.8], [0.1, 0.1, 0.8]],
                               [[0.4, 0.5, 0.1], [0.2, 0.4, 0.4]],
                               [[0.05, 0.9, 0.05], [0.4, 0.4, 0.2]]])

  def symbols_to_logits(ids):
    pos = tf.shape(ids)[1]
    logits = tf.to_float(tf.log(probabilities[pos - 1, :]))
    return logits

  final_ids, final_probs = beam_search.beam_search(
      symbols_to_logits,
      initial_ids,
      beam_size,
      decode_length,
      vocab_size,
      0.0,
      eos_id=1,
      stop_early=False)

  with self.test_session():
    ids = final_ids.eval()
    probs = final_probs.eval()
  # given stop_early = False, the algorithm will return all the beams,
  # so we can test all of them here
  self.assertAllEqual([[[0, 2, 1, 0], [0, 2, 0, 1]]], ids)
  self.assertAllClose([[0.8 * 0.5, 0.8 * 0.4 * 0.9]], np.exp(probs))
def testGreedyWithCornerCase(self):
  batch_size = 1
  beam_size = 1
  vocab_size = 3
  decode_length = 2

  initial_ids = tf.constant([0] * batch_size)  # GO
  probabilities = tf.constant([[0.2, 0.1, 0.7], [0.4, 0.1, 0.5]])

  def symbols_to_logits(ids):
    pos = tf.shape(ids)[1]
    logits = tf.to_float(tf.log(probabilities[pos - 1, :]))
    return logits

  final_ids, final_probs = beam_search.beam_search(
      symbols_to_logits, initial_ids, beam_size, decode_length, vocab_size,
      0.0, eos_id=1)

  with self.test_session():
    ids = final_ids.eval()
    probs = final_probs.eval()
  self.assertAllEqual([[[0, 2, 2]]], ids)
  self.assertAllClose([[0.7 * 0.5]], np.exp(probs))
def testNotGreedyBeamTwoWithoutStopEarly(self):
  batch_size = 1
  beam_size = 2
  vocab_size = 3
  decode_length = 3

  initial_ids = tf.constant([0] * batch_size)  # GO
  probabilities = tf.constant([[[0.1, 0.1, 0.8], [0.1, 0.1, 0.8]],
                               [[0.4, 0.5, 0.1], [0.2, 0.4, 0.4]],
                               [[0.05, 0.9, 0.05], [0.4, 0.4, 0.2]]])

  def symbols_to_logits(ids):
    pos = tf.shape(ids)[1]
    logits = tf.to_float(tf.log(probabilities[pos - 1, :]))
    return logits

  final_ids, final_probs, _ = beam_search.beam_search(
      symbols_to_logits,
      initial_ids,
      beam_size,
      decode_length,
      vocab_size,
      0.0,
      eos_id=1,
      stop_early=False)

  with self.test_session():
    ids = final_ids.eval()
    probs = final_probs.eval()
  # given stop_early = False, the algorithm will return all the beams,
  # so we can test all of them here
  self.assertAllEqual([[[0, 2, 1, 0], [0, 2, 0, 1]]], ids)
  self.assertAllClose([[0.8 * 0.5, 0.8 * 0.4 * 0.9]], np.exp(probs))
def ae_latent_sample_beam(latents_dense_in, inputs, ed, embed, hparams):
  """Sample from the latent space in the autoencoder."""
  vocab_size = 2**hparams.z_size
  beam_size = 1  # TODO(lukaszkaiser): larger beam sizes seem to work bad.
  inputs = tf.tile(inputs, [beam_size, 1, 1])
  ed = tf.tile(ed, [beam_size, 1, 1, 1])

  def symbols_to_logits_fn(ids):
    """Go from ids to logits."""
    ids = tf.expand_dims(ids, axis=2)  # Ids start with added all-zeros.
    latents_discrete = tf.pad(ids[:, 1:], [[0, 0], [0, 1], [0, 0]])
    with tf.variable_scope(tf.get_variable_scope(), reuse=False):
      latents_dense = embed(latents_discrete)
      latents_pred = decode_transformer(
          inputs, ed, latents_dense, hparams, "extra")
      logits = tf.layers.dense(latents_pred, vocab_size, name="extra_logits")
      current_output_position = common_layers.shape_list(ids)[1] - 1
      logits = logits[:, current_output_position, :, :]
    return tf.squeeze(logits, axis=[1])

  initial_ids = tf.zeros([tf.shape(latents_dense_in)[0]], dtype=tf.int32)
  length = tf.shape(latents_dense_in)[1]
  ids, _ = beam_search.beam_search(
      symbols_to_logits_fn, initial_ids, beam_size, length, vocab_size,
      alpha=0.0, eos_id=-1, stop_early=False)

  res = tf.expand_dims(ids[:, 0, :], axis=2)  # Pick first beam.
  return res[:, 1:]  # Remove the added all-zeros from ids.
def testGreedyBatchOne(self):
  batch_size = 1
  beam_size = 1
  vocab_size = 2
  decode_length = 3

  initial_ids = tf.constant([0] * batch_size)  # GO

  # Test that beam search finds the most probable sequence.
  # These probabilities represent the following search
  #
  #               G0 (0)
  #                  / \
  #                /     \
  #              /         \
  #            /             \
  #         0(0.7)          1(0.3)
  #           / \
  #          /   \
  #         /     \
  #     0(0.4)    1(0.6)
  #        /\
  #       /  \
  #      /    \
  #    0(0.5)  1(0.5)
  #
  # and the following decoding probabilities
  # 0000 - 0.7 * 0.4 * 0.5
  # 0001 - 0.7 * 0.4 * 0.5
  # 001  - 0.7 * 0.6 (Best)
  # 01   - 0.3
  #
  # 001 is the most likely sequence under these probabilities.
  probabilities = tf.constant([[[0.7, 0.3]], [[0.4, 0.6]], [[0.5, 0.5]]])

  def symbols_to_logits(ids):
    pos = tf.shape(ids)[1]
    logits = tf.to_float(tf.log(probabilities[pos - 1, :]))
    return logits

  final_ids, final_probs = beam_search.beam_search(
      symbols_to_logits, initial_ids, beam_size, decode_length, vocab_size,
      0.0, eos_id=1)

  with self.test_session():
    ids = final_ids.eval()
    probs = final_probs.eval()
  self.assertAllEqual([[[0, 0, 1]]], ids)
  self.assertAllClose([[0.7 * 0.6]], np.exp(probs))
def beam_search_decode(self, features, hidden_feature, mode, problem_name):
  # prepare inputs to attention
  key = 'ori_seq' if self.params.label_transfer else 'seq'
  encoder_outputs = hidden_feature[key]
  max_seq_len = self.params.max_seq_len
  embedding_table = hidden_feature['embed_table']
  token_type_ids = features['segment_ids']
  num_classes = self.params.num_classes[problem_name]
  batch_size = modeling.get_shape_list(encoder_outputs, expected_rank=3)[0]
  hidden_size = self.params.bert_config.hidden_size

  if self.params.problem_type[problem_name] == 'seq2seq_text':
    embedding_table = hidden_feature['embed_table']
  else:
    embedding_table = tf.get_variable(
        'tag_embed_table', shape=[num_classes, hidden_size])

  symbol_to_logit_fn = self._get_symbol_to_logit_fn(
      max_seq_len=max_seq_len,
      embedding_table=embedding_table,
      token_type_ids=token_type_ids,
      decoder=self.decoder,
      num_classes=num_classes,
      encoder_output=encoder_outputs,
      input_mask=features['input_mask'],
      params=self.params)

  # create cache for fast decode
  cache = {
      str(layer): {
          "key_layer": tf.zeros([batch_size, 0, hidden_size]),
          "value_layer": tf.zeros([batch_size, 0, hidden_size]),
      } for layer in range(self.params.decoder_num_hidden_layers)
  }
  # cache['encoder_outputs'] = encoder_outputs
  # cache['encoder_decoder_attention_mask'] = features['input_mask']

  initial_ids = tf.zeros([batch_size], dtype=tf.int32)

  decode_ids, _ = beam_search.beam_search(
      symbols_to_logits_fn=symbol_to_logit_fn,
      initial_ids=initial_ids,
      states=cache,
      vocab_size=self.params.num_classes[problem_name],
      beam_size=self.params.beam_size,
      alpha=self.params.beam_search_alpha,
      decode_length=self.params.decode_max_seq_len,
      eos_id=self.params.eos_id[problem_name])

  # Get the top sequence for each batch element
  top_decoded_ids = decode_ids[:, 0, 1:]
  self.prob = top_decoded_ids
  return self.prob
def ae_latent_sample_beam(latents_dense_in, inputs, ed, embed, hparams):
  """Samples from the latent space in the autoencoder.

  Args:
    latents_dense_in: Tensor of shape [batch, length_q, ...]. Only the shape
      of its first two dimensions is used. length_q is the latent length,
      which is height * width * hparams.num_latents /
      (2**hparams.num_compress_steps).
    inputs: Tensor of shape [batch, length_kv, hparams.hidden_size]. Encodings
      to attend to in decoder.
    ed: Tensor which broadcasts with shape
      [batch, hparams.num_heads, length_q, length_kv]. Encoder-decoder
      attention bias.
    embed: Callable which embeds discrete latent hot-vectors and a hidden size
      and returns dense vectors.
    hparams: HParams.

  Returns:
    Tensor of shape [batch, length].
  """

  def symbols_to_logits_fn(ids):
    """Go from ids to logits."""
    ids = tf.expand_dims(ids, axis=2)  # Ids start with added all-zeros.
    latents_discrete = tf.pad(ids[:, 1:], [[0, 0], [0, 1], [0, 0]])
    with tf.variable_scope(tf.get_variable_scope(), reuse=False):
      latents_dense = embed(
          tf.one_hot(latents_discrete, depth=2**hparams.bottleneck_bits),
          hparams.hidden_size)
      latents_pred = transformer_latent_decoder(
          latents_dense, inputs, ed, hparams, name="latent_prediction")
      logits = tf.layers.dense(
          latents_pred, 2**hparams.bottleneck_bits, name="logits_dense")
      current_output_position = common_layers.shape_list(ids)[1] - 1
      logits = logits[:, current_output_position, :]
    return logits

  initial_ids = tf.zeros([tf.shape(latents_dense_in)[0]], dtype=tf.int32)
  length = tf.shape(latents_dense_in)[1]
  ids, _, _ = beam_search.beam_search(
      symbols_to_logits_fn,
      initial_ids,
      1,
      length,
      2**hparams.bottleneck_bits,
      alpha=0.0,
      eos_id=-1,
      stop_early=False)

  res = tf.expand_dims(ids[:, 0, :], axis=2)  # Pick first beam.
  return res[:, 1:]  # Remove the added all-zeros from ids.
def ae_latent_sample_beam(latents_dense_in, inputs, ed, embed, hparams):
  """Samples from the latent space in the autoencoder.

  Args:
    latents_dense_in: Tensor of shape [batch, length_q, ...]. Only the shape
      of its first two dimensions is used. length_q is the latent length,
      which is height * width * hparams.num_latents /
      (2**hparams.num_compress_steps).
    inputs: Tensor of shape [batch, length_kv, hparams.hidden_size]. Encodings
      to attend to in decoder.
    ed: Tensor which broadcasts with shape
      [batch, hparams.num_heads, length_q, length_kv]. Encoder-decoder
      attention bias.
    embed: Callable which embeds discrete latent hot-vectors and a hidden size
      and returns dense vectors.
    hparams: tf.contrib.training.HParams.

  Returns:
    Tensor of shape [batch, length].
  """

  def symbols_to_logits_fn(ids):
    """Go from ids to logits."""
    ids = tf.expand_dims(ids, axis=2)  # Ids start with added all-zeros.
    latents_discrete = tf.pad(ids[:, 1:], [[0, 0], [0, 1], [0, 0]])
    with tf.variable_scope(tf.get_variable_scope(), reuse=False):
      latents_dense = embed(
          tf.one_hot(latents_discrete, depth=2**hparams.bottleneck_bits),
          hparams.hidden_size)
      latents_pred = transformer_latent_decoder(
          latents_dense, inputs, ed, hparams, name="latent_prediction")
      logits = tf.layers.dense(
          latents_pred, 2**hparams.bottleneck_bits, name="logits_dense")
      current_output_position = common_layers.shape_list(ids)[1] - 1
      logits = logits[:, current_output_position, :]
    return logits

  initial_ids = tf.zeros([tf.shape(latents_dense_in)[0]], dtype=tf.int32)
  length = tf.shape(latents_dense_in)[1]
  ids, _ = beam_search.beam_search(
      symbols_to_logits_fn,
      initial_ids,
      1,
      length,
      2**hparams.bottleneck_bits,
      alpha=0.0,
      eos_id=-1,
      stop_early=False)

  res = tf.expand_dims(ids[:, 0, :], axis=2)  # Pick first beam.
  return res[:, 1:]  # Remove the added all-zeros from ids.
def testTPUBeam(self):
  batch_size = 1
  beam_size = 2
  vocab_size = 3
  decode_length = 3

  initial_ids = tf.constant([0] * batch_size)  # GO
  probabilities = tf.constant([[[0.1, 0.1, 0.8], [0.1, 0.1, 0.8]],
                               [[0.4, 0.5, 0.1], [0.2, 0.4, 0.4]],
                               [[0.05, 0.9, 0.05], [0.4, 0.4, 0.2]]])

  # The top beam is always selected, so we should see the top beam's state
  # at each position, which is the one that's getting 3 added to it each step.
  expected_states = tf.constant([[[0.], [0.]], [[3.], [3.]], [[6.], [6.]]])

  def symbols_to_logits(_, i, states):
    # We have to assert the values of state inline here since we can't fetch
    # them out of the loop!
    with tf.control_dependencies(
        [tf.assert_equal(states["state"], expected_states[i])]):
      logits = tf.to_float(tf.log(probabilities[i, :]))
    states["state"] += tf.constant([[3.], [7.]])
    return logits, states

  states = {
      "state": tf.zeros((batch_size, 1)),
  }
  states["state"] = tf.placeholder_with_default(
      states["state"], shape=(None, 1))

  final_ids, _ = beam_search.beam_search(
      symbols_to_logits,
      initial_ids,
      beam_size,
      decode_length,
      vocab_size,
      3.5,
      eos_id=1,
      states=states,
      use_tpu=True)

  with self.test_session() as sess:
    # Catch and fail so that the testing framework doesn't think it's an error
    try:
      final_ids = sess.run(final_ids)
    except tf.errors.InvalidArgumentError as e:
      raise AssertionError(e.message)

  self.assertAllEqual([[[0, 2, 0, 1], [0, 2, 1, 0]]], final_ids)
def testStateBeamTwo(self):
  batch_size = 1
  beam_size = 2
  vocab_size = 3
  decode_length = 3

  initial_ids = tf.constant([0] * batch_size)  # GO
  probabilities = tf.constant([[[0.1, 0.1, 0.8], [0.1, 0.1, 0.8]],
                               [[0.4, 0.5, 0.1], [0.2, 0.4, 0.4]],
                               [[0.05, 0.9, 0.05], [0.4, 0.4, 0.2]]])

  # The top beam is always selected, so we should see the top beam's state
  # at each position, which is the one that's getting 3 added to it each step.
  expected_states = tf.constant([[[0.], [0.]], [[3.], [3.]], [[6.], [6.]]])

  def symbols_to_logits(ids, _, states):
    pos = tf.shape(ids)[1] - 1
    # We have to assert the values of state inline here since we can't fetch
    # them out of the loop!
    with tf.control_dependencies(
        [tf.assert_equal(states["state"], expected_states[pos])]):
      logits = tf.to_float(tf.log(probabilities[pos, :]))
    states["state"] += tf.constant([[3.], [7.]])
    return logits, states

  states = {
      "state": tf.zeros((batch_size, 1)),
  }
  states["state"] = tf.placeholder_with_default(
      states["state"], shape=(None, 1))

  final_ids, _, _ = beam_search.beam_search(
      symbols_to_logits,
      initial_ids,
      beam_size,
      decode_length,
      vocab_size,
      0.0,
      eos_id=1,
      states=states)

  with self.test_session() as sess:
    # Catch and fail so that the testing framework doesn't think it's an error
    try:
      sess.run(final_ids)
    except tf.errors.InvalidArgumentError as e:
      raise AssertionError(e.message)
def beam_search_decoding(length=1000):
  initial_ids = tf.constant(char2idx['<start>'], tf.int32, [1])

  def symbols_to_logits(ids):
    logits = model.forward(ids)
    return logits[:, tf.shape(ids)[1] - 1, :]

  final_ids, final_probs, _ = beam_search.beam_search(
      symbols_to_logits, initial_ids, 5, length, len(char2idx), 0.0,
      eos_id=char2idx['<end>'])
  return final_ids[0, 0, :]
def beam_search_decoding(length=20, beam_width=5):
  initial_ids = tf.fill([model.batch_size], GO)

  def symbols_to_logits(ids):
    x = tf.contrib.seq2seq.tile_batch(model.X, beam_width)
    logits = model.forward(x, ids, reuse=True)
    return logits[:, tf.shape(ids)[1] - 1, :]

  final_ids, final_probs, _ = beam_search.beam_search(
      symbols_to_logits, initial_ids, beam_width, length, len(vocab2id), 0.0,
      eos_id=EOS)
  return final_ids
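# A minimal, self-contained sketch (an illustration, not part of the snippets
# above) of the stateless calling convention these examples share. It assumes
# tensor2tensor's TF1-style `beam_search`: the callback maps decoded ids of
# shape [batch * beam, decoded_length] to next-step logits of shape
# [batch * beam, vocab_size]. Depending on the library version, beam_search
# returns (ids, scores) or (ids, scores, states), which is why the snippets
# above unpack two or three values; indexing the result tuple sidesteps that.
import tensorflow as tf
from tensor2tensor.utils import beam_search

batch_size, beam_size, vocab_size, decode_length = 1, 2, 4, 5
initial_ids = tf.zeros([batch_size], dtype=tf.int32)  # GO

def toy_symbols_to_logits(ids):
  # Uniform toy distribution; a real model would condition on `ids`.
  return tf.zeros([tf.shape(ids)[0], vocab_size])

result = beam_search.beam_search(
    toy_symbols_to_logits, initial_ids, beam_size, decode_length,
    vocab_size, 0.0, eos_id=1)
decoded_ids, scores = result[0], result[1]
# decoded_ids: [batch, beam, <= decode_length + 1]; scores: [batch, beam].
top_sequence = decoded_ids[:, 0, 1:]  # best beam, initial GO id stripped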
def testNotGreedyBatchTwoBeamTwoWithAlpha(self):
  batch_size = 2
  beam_size = 2
  vocab_size = 3
  decode_length = 3

  initial_ids = tf.constant([0] * batch_size)  # GO

  # Probabilities for position * batch * beam * vocab
  # Probabilities have been set such that with alpha = 3.5, the less probable
  # but longer sequence will have a better score than the shorter sequence
  # with higher log prob in batch 1, and the order will be reversed in batch
  # 2. That is, the shorter sequence will still have a higher score in spite
  # of the length penalty.
  probabilities = tf.constant([[[[0.1, 0.1, 0.8], [0.1, 0.1, 0.8]],
                                [[0.1, 0.1, 0.8], [0.1, 0.1, 0.8]]],
                               [[[0.4, 0.5, 0.1], [0.2, 0.4, 0.4]],
                                [[0.3, 0.6, 0.1], [0.2, 0.4, 0.4]]],
                               [[[0.05, 0.9, 0.05], [0.4, 0.4, 0.2]],
                                [[0.05, 0.9, 0.05], [0.4, 0.4, 0.2]]]])

  def symbols_to_logits(ids):
    pos = tf.shape(ids)[1]
    logits = tf.to_float(tf.log(probabilities[pos - 1, :]))
    return logits

  final_ids, final_scores = beam_search.beam_search(
      symbols_to_logits, initial_ids, beam_size, decode_length, vocab_size,
      3.5, eos_id=1)

  with self.test_session():
    ids = final_ids.eval()
    scores = final_scores.eval()
  self.assertAllEqual(
      [[[0, 2, 0, 1], [0, 2, 1, 0]], [[0, 2, 1, 0], [0, 2, 0, 1]]], ids)
  self.assertAllClose(
      [[np.log(0.8 * 0.4 * 0.9) / (8. / 6.)**3.5,
        np.log(0.8 * 0.5) / (7. / 6.)**3.5],
       [np.log(0.8 * 0.6) / (7. / 6.)**3.5,
        np.log(0.8 * 0.3 * 0.9) / (8. / 6.)**3.5]], scores)
def predict(model, inputs, inpf, tarf, bos_id, eos_id, beam_size, vocab_size,
            alpha=1.0, decode_length=40):
  """
  inputs: already int encoded set of inputs, [batch_size, ?], tf.int32
  """
  batch_size = inputs.shape[0]
  initial_ids = [bos_id] * batch_size

  enc_input = tf.expand_dims(inputs, 1)
  enc_input = tf.tile(enc_input, [1, beam_size, 1])
  enc_input = tf.reshape(enc_input, [batch_size * beam_size, -1])

  def symbols_to_logits(ids):
    logits = model([
        enc_input,
        tf.tile(tf.expand_dims(inpf, 0), [tf.shape(ids)[0], 1]),
        ids,
        tf.tile(tf.expand_dims(tarf, 0), [tf.shape(ids)[0], 1]),
    ])
    logits = logits[0][:, -1, :]
    return logits

  x = beam_search(
      symbols_to_logits, initial_ids, beam_size, decode_length, vocab_size,
      alpha=alpha, eos_id=eos_id)
  ids = x[0]
  probs = x[1]
  return ids, probs
def testStates(self):
  batch_size = 1
  beam_size = 1
  vocab_size = 2
  decode_length = 3

  initial_ids = tf.constant([0] * batch_size)  # GO
  probabilities = tf.constant([[[0.7, 0.3]], [[0.4, 0.6]], [[0.5, 0.5]]])

  expected_states = tf.constant([[[0.]], [[1.]]])

  def symbols_to_logits(ids, _, states):
    pos = tf.shape(ids)[1] - 1
    # We have to assert the values of state inline here since we can't fetch
    # them out of the loop!
    with tf.control_dependencies(
        [tf.assert_equal(states["state"], expected_states[pos])]):
      logits = tf.to_float(tf.log(probabilities[pos, :]))
    states["state"] += 1
    return logits, states

  states = {
      "state": tf.zeros((batch_size, 1)),
  }
  states["state"] = tf.placeholder_with_default(
      states["state"], shape=(None, 1))

  final_ids, _, _ = beam_search.beam_search(
      symbols_to_logits,
      initial_ids,
      beam_size,
      decode_length,
      vocab_size,
      0.0,
      eos_id=1,
      states=states)

  with self.test_session() as sess:
    # Catch and fail so that the testing framework doesn't think it's an error
    try:
      sess.run(final_ids)
    except tf.errors.InvalidArgumentError as e:
      raise AssertionError(e.message)
def testStates(self):
  batch_size = 1
  beam_size = 1
  vocab_size = 2
  decode_length = 3

  initial_ids = tf.constant([0] * batch_size)  # GO
  probabilities = tf.constant([[[0.7, 0.3]], [[0.4, 0.6]], [[0.5, 0.5]]])

  expected_states = tf.constant([[[0.]], [[1.]]])

  def symbols_to_logits(ids, _, states):
    pos = tf.shape(ids)[1] - 1
    # We have to assert the values of state inline here since we can't fetch
    # them out of the loop!
    with tf.control_dependencies(
        [tf.assert_equal(states["state"], expected_states[pos])]):
      logits = tf.to_float(tf.log(probabilities[pos, :]))
    states["state"] += 1
    return logits, states

  states = {
      "state": tf.zeros((batch_size, 1)),
  }
  states["state"] = tf.placeholder_with_default(
      states["state"], shape=(None, 1))

  final_ids, _ = beam_search.beam_search(
      symbols_to_logits,
      initial_ids,
      beam_size,
      decode_length,
      vocab_size,
      0.0,
      eos_id=1,
      states=states)

  with self.test_session() as sess:
    # Catch and fail so that the testing framework doesn't think it's an error
    try:
      sess.run(final_ids)
    except tf.errors.InvalidArgumentError as e:
      raise AssertionError(e.message)
def testShapes(self):
  batch_size = 2
  beam_size = 3
  vocab_size = 4
  decode_length = 10

  initial_ids = tf.constant([0, 0])  # GO

  def symbols_to_logits(_):
    # Just return random logits
    return tf.random_uniform((batch_size * beam_size, vocab_size))

  final_ids, final_probs = beam_search.beam_search(
      symbols_to_logits, initial_ids, beam_size, decode_length, vocab_size,
      0.)

  self.assertEqual(final_ids.get_shape().as_list(), [None, beam_size, None])
  self.assertEqual(final_probs.get_shape().as_list(),
                   [batch_size, beam_size])
def forward(self, inputs, labels, masks, training):
  b = tf.shape(labels)[0]
  label_inputs = tf.concat(
      [tf.zeros((b, 1), tf.int64), labels[:, :-1]], axis=1)
  mask_inputs = tf.concat(
      [tf.ones((b, 1), tf.int32), masks[:, :-1]], axis=1)

  emb_inputs = tf.nn.embedding_lookup(self.embedding, label_inputs)
  rnn_masks = tf.cast(tf.expand_dims(mask_inputs, 2), tf.float32)
  rnn_inputs = tf.multiply(emb_inputs, rnn_masks)
  self.dropout.apply(rnn_inputs, training=training)

  h0 = self.fc.apply(inputs)
  rnn_outputs = self.decoder.apply(rnn_inputs, initial_state=[h0, h0])
  self.dropout.apply(rnn_outputs, training=training)

  logits = tf.matmul(rnn_outputs, self.embedding, transpose_b=True)
  loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
      labels=labels, logits=logits)
  loss = tf.reduce_sum(loss * tf.cast(masks, tf.float32), axis=-1)
  loss = tf.reduce_mean(loss)

  def pred_fn(x, i, states):
    e = tf.nn.embedding_lookup(self.embedding, x)
    r = self.decoder.apply(e, initial_state=states)[:, -1, :]
    o = tf.matmul(r, self.embedding, transpose_b=True)
    return o, states

  initial_ids = tf.ones((b,), tf.int32) * self.sos_id
  beam_preds, _, _ = beam_search.beam_search(
      pred_fn,
      initial_ids,
      alpha=0.,
      beam_size=self.beam_size,
      decode_length=self.seq_len,
      vocab_size=self.output_size + 1,
      eos_id=self.eos_id,
      states=[h0, h0])
  return beam_preds[:, 0, 1:], loss
def testNotGreedyBeamTwoWithStopEarly(self):
  batch_size = 1
  beam_size = 2
  vocab_size = 3
  decode_length = 3

  initial_ids = tf.constant([0] * batch_size)  # GO
  probabilities = tf.constant([[[0.1, 0.1, 0.8], [0.1, 0.1, 0.8]],
                               [[0.4, 0.5, 0.1], [0.2, 0.4, 0.4]],
                               [[0.05, 0.9, 0.05], [0.4, 0.4, 0.2]]])

  def symbols_to_logits(ids):
    pos = tf.shape(ids)[1]
    logits = tf.to_float(tf.log(probabilities[pos - 1, :]))
    return logits

  final_ids, final_probs = beam_search.beam_search(
      symbols_to_logits,
      initial_ids,
      beam_size,
      decode_length,
      vocab_size,
      0.0,
      eos_id=1,
      stop_early=True)  # default value, but just to make this explicit

  with self.test_session():
    ids = final_ids.eval()
    probs = final_probs.eval()
  # given stop_early = True, the only 'assurance' is w.r.t. the first beam
  # (i.e., other beams may not even be completed), so we check only the
  # first beam
  first_beam = ids[:, 0]
  first_probs = probs[:, 0]
  self.assertAllEqual([[0, 2, 1]], first_beam)
  self.assertAllClose([0.8 * 0.5], np.exp(first_probs))
def _beam_decode_slow(self, features, decode_length, beam_size, top_beams,
                      alpha):
  """Slow version of Beam search decoding. Quadratic time in decode_length.

  Args:
    features: a map of string to `Tensor`
    decode_length: an integer. How many additional timesteps to decode.
    beam_size: number of beams.
    top_beams: an integer. How many of the beams to return.
    alpha: Float that controls the length penalty. The larger the alpha,
      the stronger the preference for longer translations.

  Returns:
    samples: an integer `Tensor`. Top samples from the beam search
  """
  batch_size = common_layers.shape_list(features["inputs"])[0]
  batch_size = tf.Print(batch_size, [batch_size], "beam_decode batch_size=")

  def symbols_to_logits_fn(ids):
    """Go from ids to logits."""
    ids = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
    ids = tf.pad(ids[:, 1:], [[0, 0], [0, 1], [0, 0], [0, 0]])
    if "partial_targets" in features:
      pt = features["partial_targets"]
      pt_length = common_layers.shape_list(pt)[1]
      pt = tf.tile(pt, [1, beam_size])
      pt = tf.reshape(pt, [batch_size * beam_size, pt_length, 1, 1])
      ids = tf.concat([pt, ids], axis=1)

    features["targets"] = ids
    self._coverage = None
    logits, _ = self(features)  # pylint: disable=not-callable
    # now self._coverage is a coverage tensor for the first datashard.
    # it has shape [batch_size] and contains floats between 0 and
    # source_length.
    modality = self.hparams.problems[self._problem_idx].target_modality
    if modality.top_is_pointwise:
      return tf.squeeze(logits, axis=[1, 2, 3])
    # -1 due to the pad above.
    current_output_position = common_layers.shape_list(ids)[1] - 1
    logits = logits[:, current_output_position, :, :]
    return tf.squeeze(logits, axis=[1, 2])

  initial_ids = tf.zeros([batch_size], dtype=tf.int32)

  if self.has_input:
    inputs_old = features["inputs"]
    features["inputs"] = tf.expand_dims(features["inputs"], 1)
    if len(features["inputs"].shape) < 5:
      features["inputs"] = tf.expand_dims(features["inputs"], 4)
    # Expand the inputs into the beam size.
    features["inputs"] = tf.tile(features["inputs"], [1, beam_size, 1, 1, 1])
    s = common_layers.shape_list(features["inputs"])
    features["inputs"] = tf.reshape(features["inputs"],
                                    [s[0] * s[1], s[2], s[3], s[4]])

  target_modality = self.hparams.problems[self._problem_idx].target_modality
  vocab_size = target_modality.top_dimensionality
  # Setting decode length to input length + decode_length
  decode_length = tf.constant(decode_length)
  if "partial_targets" not in features:
    decode_length += common_layers.shape_list(features["inputs"])[1]
  ids, scores = beam_search.beam_search(
      symbols_to_logits_fn,
      initial_ids,
      beam_size,
      decode_length,
      vocab_size,
      alpha,
      stop_early=(top_beams == 1))

  # Set inputs back to the unexpanded inputs so as not to confuse the
  # Estimator!
  if self.has_input:
    features["inputs"] = inputs_old

  # Return `top_beams` decodings (also remove initial id from the beam search)
  return_scores = True  # TODO(lukaszkaiser): make it work multi-problem.
  if top_beams == 1:
    if return_scores:
      return {"outputs": ids[:, 0, 1:], "scores": scores}
    return ids[:, 0, 1:]
  else:
    if return_scores:
      return {"outputs": ids[:, :top_beams, 1:], "scores": scores}
    return ids[:, :top_beams, 1:]
def _beam_decode_slow(self, features, decode_length, beam_size, top_beams,
                      last_position_only, alpha):
  """Slow version of Beam search decoding. Quadratic time in decode_length.

  Args:
    features: a map of string to `Tensor`
    decode_length: an integer. How many additional timesteps to decode.
    beam_size: number of beams.
    top_beams: an integer. How many of the beams to return.
    last_position_only: a boolean, speed-up by computing last position only.
    alpha: Float that controls the length penalty. The larger the alpha,
      the stronger the preference for longer translations.

  Returns:
    samples: an integer `Tensor`. Top samples from the beam search
  """
  batch_size = tf.shape(features["inputs"])[0]
  batch_size = tf.Print(batch_size, [batch_size], "beam_decode batch_size=")

  def symbols_to_logits_fn(ids):
    """Go from ids to logits."""
    ids = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
    ids = tf.pad(ids[:, 1:], [[0, 0], [0, 1], [0, 0], [0, 0]])
    if "partial_targets" in features:
      pt = features["partial_targets"]
      pt_length = tf.shape(pt)[1]
      pt = tf.tile(pt, [1, beam_size])
      pt = tf.reshape(pt, [batch_size * beam_size, pt_length, 1, 1])
      ids = tf.concat([pt, ids], axis=1)

    features["targets"] = ids
    self._coverage = None
    sharded_logits, _ = self.model_fn(
        features, False, last_position_only=last_position_only)
    # now self._coverage is a coverage tensor for the first datashard.
    # it has shape [batch_size] and contains floats between 0 and
    # source_length.
    logits = sharded_logits[0]  # Assuming we have one shard.
    if last_position_only:
      return tf.squeeze(logits, axis=[1, 2, 3])
    current_output_position = tf.shape(ids)[1] - 1  # -1 due to the pad above.
    logits = logits[:, current_output_position, :, :]
    return tf.squeeze(logits, axis=[1, 2])

  initial_ids = tf.zeros([batch_size], dtype=tf.int32)

  if self.has_input:
    inputs_old = features["inputs"]
    features["inputs"] = tf.expand_dims(features["inputs"], 1)
    if len(features["inputs"].shape) < 5:
      features["inputs"] = tf.expand_dims(features["inputs"], 4)
    # Expand the inputs into the beam size.
    features["inputs"] = tf.tile(features["inputs"], [1, beam_size, 1, 1, 1])
    s = tf.shape(features["inputs"])
    features["inputs"] = tf.reshape(features["inputs"],
                                    [s[0] * s[1], s[2], s[3], s[4]])

  target_modality = self._hparams.problems[self._problem_idx].target_modality
  vocab_size = target_modality.top_dimensionality
  # Setting decode length to input length + decode_length
  decode_length = tf.constant(decode_length)
  if "partial_targets" not in features:
    decode_length += tf.shape(features["inputs"])[1]
  ids, scores = beam_search.beam_search(symbols_to_logits_fn, initial_ids,
                                        beam_size, decode_length, vocab_size,
                                        alpha)

  # Set inputs back to the unexpanded inputs so as not to confuse the
  # Estimator!
  if self.has_input:
    features["inputs"] = inputs_old

  # Return `top_beams` decodings (also remove initial id from the beam search)
  return_scores = False  # TODO(lukaszkaiser): make it work multi-problem.
  if top_beams == 1:
    if return_scores:
      return {"outputs": ids[:, 0, 1:], "scores": scores}
    return ids[:, 0, 1:]
  else:
    if return_scores:
      return {"outputs": ids[:, :top_beams, 1:], "scores": scores}
    return ids[:, :top_beams, 1:]
def fast_decode(encoder_output,
                encoder_decoder_attention_bias,
                symbols_to_logits_fn,
                hparams,
                decode_length,
                vocab_size,
                beam_size=1,
                top_beams=1,
                alpha=1.0,
                eos_id=beam_search.EOS_ID,
                batch_size=None):
  """Given encoder output and a symbols to logits function, does fast decoding.

  Implements both greedy and beam search decoding, uses beam search iff
  beam_size > 1, otherwise beam search related arguments are ignored.

  Args:
    encoder_output: Output from encoder.
    encoder_decoder_attention_bias: a bias tensor for use in encoder-decoder
      attention
    symbols_to_logits_fn: Incremental decoding; function mapping triple
      `(ids, step, cache)` to symbol logits.
    hparams: run hyperparameters
    decode_length: an integer. How many additional timesteps to decode.
    vocab_size: Output vocabulary size.
    beam_size: number of beams.
    top_beams: an integer. How many of the beams to return.
    alpha: Float that controls the length penalty. The larger the alpha,
      the stronger the preference for longer translations.
    eos_id: End-of-sequence symbol in beam search.
    batch_size: an integer scalar - must be passed if there is no input

  Returns:
    A dict of decoding results {
        "outputs": integer `Tensor` of decoded ids of shape
            [batch_size, <= decode_length] if top_beams == 1 or
            [batch_size, top_beams, <= decode_length] otherwise
        "scores": decoding log probs from the beam search,
            None if using greedy decoding (beam_size=1)
    }

  Raises:
    NotImplementedError: If beam size > 1 with partial targets.
  """
  if encoder_output is not None:
    batch_size = common_layers.shape_list(encoder_output)[0]

  key_channels = hparams.attention_key_channels or hparams.hidden_size
  value_channels = hparams.attention_value_channels or hparams.hidden_size
  num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers

  cache = {
      "layer_%d" % layer: {
          "k": tf.zeros([batch_size, 0, key_channels]),
          "v": tf.zeros([batch_size, 0, value_channels]),
      } for layer in range(num_layers)
  }

  if encoder_output is not None:
    cache["encoder_output"] = encoder_output
    cache["encoder_decoder_attention_bias"] = encoder_decoder_attention_bias

  if beam_size > 1:  # Beam Search
    initial_ids = tf.zeros([batch_size], dtype=tf.int32)
    decoded_ids, scores = beam_search.beam_search(
        symbols_to_logits_fn,
        initial_ids,
        beam_size,
        decode_length,
        vocab_size,
        alpha,
        states=cache,
        eos_id=eos_id,
        stop_early=(top_beams == 1))

    if top_beams == 1:
      decoded_ids = decoded_ids[:, 0, 1:]
    else:
      decoded_ids = decoded_ids[:, :top_beams, 1:]

    """ t2t_csaky code """
    # do roulette wheel selection or inverse roulette wheel selection
    if hparams.roulette == "Normal" or hparams.roulette == "Inverse":
      if hparams.roulette == "Normal":
        probabilities = tf.pow(tf.constant(2.0), scores)
        start = 0
      else:
        probabilities = tf.subtract(
            tf.constant(1.0), tf.pow(tf.constant(2.0), scores))
        start = beam_size - hparams.roulette_beam_size

      ex_probs = tf.divide(probabilities, tf.reduce_sum(probabilities))
      # ex_probs = tf.nn.softmax(probabilities)

      # sample a number between 0 and 1
      wheel = tf.random_uniform([1])
      upper_bound = tf.constant(0.0)

      # change this as well if using inverse
      for i in range(start, hparams.roulette_beam_size):
        upper_bound = tf.add(ex_probs[:, i], upper_bound)
        truthValue = tf.squeeze(
            tf.logical_and(wheel >= upper_bound - ex_probs[:, i],
                           wheel <= upper_bound))
        decoded_ids, scores, i = tf.cond(
            truthValue,
            lambda: (decoded_ids[:, i, :], scores[:, i], beam_size),
            lambda: (decoded_ids, scores, i))

  else:  # Greedy

    def inner_loop(i, finished, next_id, decoded_ids, cache):
      """One step of greedy decoding."""
      logits, cache = symbols_to_logits_fn(next_id, i, cache)
      temperature = (0.0 if hparams.sampling_method == "argmax" else
                     hparams.sampling_temp)
      next_id = common_layers.sample_with_temperature(logits, temperature)
      finished |= tf.equal(next_id, eos_id)
      next_id = tf.expand_dims(next_id, axis=1)
      decoded_ids = tf.concat([decoded_ids, next_id], axis=1)
      return i + 1, finished, next_id, decoded_ids, cache

    def is_not_finished(i, finished, *_):
      return (i < decode_length) & tf.logical_not(tf.reduce_all(finished))

    decoded_ids = tf.zeros([batch_size, 0], dtype=tf.int64)
    finished = tf.fill([batch_size], False)
    next_id = tf.zeros([batch_size, 1], dtype=tf.int64)
    _, _, _, decoded_ids, _ = tf.while_loop(
        is_not_finished,
        inner_loop,
        [tf.constant(0), finished, next_id, decoded_ids, cache],
        shape_invariants=[
            tf.TensorShape([]),
            tf.TensorShape([None]),
            tf.TensorShape([None, None]),
            tf.TensorShape([None, None]),
            nest.map_structure(beam_search.get_state_shape_invariants, cache),
        ])
    scores = None

  return {
      "outputs": decoded_ids,
      "encoder_outputs": encoder_output,
      "scores": scores
  }
def transformer_beam_search(self, encoder_outputs, encoder_attn_bias,
                            encoder_embed_inputs_list,
                            sentence_complex_input_placeholder, emb_simple,
                            w, b, rule_id_input_placeholder, mem_contexts,
                            mem_outputs, global_step, score, obj,
                            obj_tensors):
  # Use Beam Search in evaluation stage
  # Update [a, b, c] to [a, a, a, b, b, b, c, c, c] if beam_search_size == 3
  encoder_beam_outputs = tf.concat(
      [tf.tile(tf.expand_dims(encoder_outputs[o, :, :], axis=0),
               [self.model_config.beam_search_size, 1, 1])
       for o in range(self.model_config.batch_size)], axis=0)

  encoder_embed_inputs = tf.stack(encoder_embed_inputs_list, axis=1)
  encoder_beam_embed_inputs = tf.concat(
      [tf.tile(tf.expand_dims(encoder_embed_inputs[o, :, :], axis=0),
               [self.model_config.beam_search_size, 1, 1])
       for o in range(self.model_config.batch_size)], axis=0)

  encoder_attn_beam_bias = tf.concat(
      [tf.tile(tf.expand_dims(encoder_attn_bias[o, :, :, :], axis=0),
               [self.model_config.beam_search_size, 1, 1, 1])
       for o in range(self.model_config.batch_size)], axis=0)

  if 'direct' in self.model_config.memory:
    obj_tensors['direct_bert_output_bak'] = obj_tensors['direct_bert_output']
    obj_tensors['direct_bert_bias_bak'] = obj_tensors['direct_bert_bias']
    obj_tensors['direct_bert_output'] = tf.concat(
        [tf.tile(tf.expand_dims(obj_tensors['direct_bert_output'][o, :, :],
                                axis=0),
                 [self.model_config.beam_search_size, 1, 1])
         for o in range(self.model_config.batch_size)], axis=0)
    obj_tensors['direct_bert_bias'] = tf.concat(
        [tf.tile(tf.expand_dims(obj_tensors['direct_bert_bias'][o, :, :, :],
                                axis=0),
                 [self.model_config.beam_search_size, 1, 1, 1])
         for o in range(self.model_config.batch_size)], axis=0)

  if (self.model_config.subword_vocab_size or
      'bert_token' in self.model_config.bert_mode):
    go_id = self.data.vocab_simple.encode(constant.SYMBOL_GO)[0]
    eos_id = self.data.vocab_simple.encode(constant.SYMBOL_END)[0]
  else:
    go_id = self.data.vocab_simple.encode(constant.SYMBOL_GO)
    eos_id = self.data.vocab_simple.encode(constant.SYMBOL_END)

  batch_go = tf.expand_dims(
      tf.tile(tf.expand_dims(self.embedding_fn(go_id, emb_simple), axis=0),
              [self.model_config.batch_size, 1]), axis=1)
  batch_go_beam = tf.concat(
      [tf.tile(tf.expand_dims(batch_go[o, :, :], axis=0),
               [self.model_config.beam_search_size, 1, 1])
       for o in range(self.model_config.batch_size)], axis=0)

  def symbol_to_logits_fn(ids):
    cur_ids = ids[:, 1:]
    embs = tf.nn.embedding_lookup(emb_simple, cur_ids)
    embs = tf.concat([batch_go_beam, embs], axis=1)

    final_outputs, _, _ = self.decode_inputs_to_outputs(
        embs, encoder_beam_outputs, encoder_attn_beam_bias,
        rule_id_input_placeholder, mem_contexts, mem_outputs, global_step,
        score, obj_tensors=obj_tensors)

    decoder_logit_list = self.output_to_logit(final_outputs[:, -1, :], w, b)

    if self.model_config.pointer_mode:
      segment_mask = None
      if 'line_comp_segids' in obj:
        segment_mask = obj['line_comp_segids']
      decoder_logit_list = word_distribution(
          [decoder_logit_list], [final_outputs[:, -1, :]],
          encoder_beam_outputs, encoder_beam_embed_inputs,
          sentence_complex_input_placeholder, obj_tensors,
          self.model_config, self.data, segment_mask, is_test=True)

    return decoder_logit_list

  beam_ids, beam_score = beam_search.beam_search(
      symbol_to_logits_fn,
      tf.ones([self.model_config.batch_size], tf.int32) * go_id,
      self.model_config.beam_search_size,
      self.model_config.max_simple_sentence,
      self.data.vocab_simple.vocab_size(),
      self.model_config.penalty_alpha,
      eos_id=eos_id)
  top_beam_ids = beam_ids[:, 0, 1:]
  top_beam_ids = tf.pad(
      top_beam_ids,
      [[0, 0],
       [0,
        self.model_config.max_simple_sentence - tf.shape(top_beam_ids)[1]]])
  decoder_target_list = [
      tf.squeeze(d, 1)
      for d in tf.split(top_beam_ids, self.model_config.max_simple_sentence,
                        axis=1)
  ]
  decoder_score = -beam_score[:, 0] / tf.to_float(tf.shape(top_beam_ids)[1])

  # Get outputs based on target ids
  decode_input_embs = tf.stack(
      self.embedding_fn(decoder_target_list, emb_simple), axis=1)
  tf.get_variable_scope().reuse_variables()

  if 'direct' in self.model_config.memory:
    obj_tensors['direct_bert_output'] = obj_tensors['direct_bert_output_bak']
    obj_tensors['direct_bert_bias'] = obj_tensors['direct_bert_bias_bak']

  final_outputs, decoder_outputs, _ = self.decode_inputs_to_outputs(
      decode_input_embs, encoder_outputs, encoder_attn_bias,
      rule_id_input_placeholder, mem_contexts, mem_outputs, global_step,
      score, obj_tensors=obj_tensors)

  output = ModelOutput(
      encoder_outputs=encoder_outputs,
      final_outputs_list=final_outputs,
      decoder_outputs_list=decoder_outputs,
      decoder_score=decoder_score,
      decoder_target_list=decoder_target_list,
      encoder_embed_inputs_list=encoder_embed_inputs_list)
  return output
def fast_decode(encoder_output,
                encoder_decoder_attention_bias,
                symbols_to_logits_fn,
                hparams,
                decode_length,
                vocab_size,
                beam_size=1,
                top_beams=1,
                alpha=1.0,
                eos_id=beam_search.EOS_ID,
                batch_size=None,
                sentence_cache=None,
                cache_flag=None):
  """Given encoder output and a symbols to logits function, does fast decoding.

  Implements both greedy and beam search decoding, uses beam search iff
  beam_size > 1, otherwise beam search related arguments are ignored.

  Args:
    encoder_output: Output from encoder.
    encoder_decoder_attention_bias: a bias tensor for use in encoder-decoder
      attention
    symbols_to_logits_fn: Incremental decoding; function mapping triple
      `(ids, step, cache)` to symbol logits.
    hparams: run hyperparameters
    decode_length: an integer. How many additional timesteps to decode.
    vocab_size: Output vocabulary size.
    beam_size: number of beams.
    top_beams: an integer. How many of the beams to return.
    alpha: Float that controls the length penalty. The larger the alpha,
      the stronger the preference for longer translations.
    eos_id: End-of-sequence symbol in beam search.
    batch_size: an integer scalar - must be passed if there is no input

  Returns:
    A dict of decoding results {
        "outputs": integer `Tensor` of decoded ids of shape
            [batch_size, <= decode_length] if top_beams == 1 or
            [batch_size, top_beams, <= decode_length] otherwise
        "scores": decoding log probs from the beam search,
            None if using greedy decoding (beam_size=1)
    }

  Raises:
    NotImplementedError: If beam size > 1 with partial targets.
  """
  if encoder_output is not None:
    batch_size = common_layers.shape_list(encoder_output)[0]

  key_channels = hparams.attention_key_channels or hparams.hidden_size
  value_channels = hparams.attention_value_channels or hparams.hidden_size
  num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers

  cache = {
      "layer_%d" % layer: {
          "k": tf.zeros([batch_size, 0, key_channels]),
          "v": tf.zeros([batch_size, 0, value_channels]),
      } for layer in range(num_layers)
  }

  if encoder_output is not None:
    cache["encoder_output"] = encoder_output
    cache["encoder_decoder_attention_bias"] = encoder_decoder_attention_bias

  if beam_size > 1:  # Beam Search
    initial_ids = tf.zeros([batch_size], dtype=tf.int32)
    # symbols_to_logits_fn returns (logits, cache, out); drop the extra
    # output so beam_search sees (logits, cache). Since `states` is passed,
    # the callback receives (ids, step, cache).
    decoded_ids, scores = beam_search.beam_search(
        lambda ids, i, c: symbols_to_logits_fn(ids, i, c)[:-1],
        initial_ids,
        beam_size,
        decode_length,
        vocab_size,
        alpha,
        states=cache,
        eos_id=eos_id,
        stop_early=(top_beams == 1))

    if top_beams == 1:
      decoded_ids = decoded_ids[:, 0, 1:]
    else:
      decoded_ids = decoded_ids[:, :top_beams, 1:]
  else:  # Greedy

    def inner_loop(cache_flag, i, finished, next_id, decoded_ids, cache):
      """One step of greedy decoding."""
      logits, cache, out = symbols_to_logits_fn(next_id, i, cache)
      temperature = (0.0 if hparams.sampling_method == "argmax" else
                     hparams.sampling_temp)
      next_id = common_layers.sample_with_temperature(logits, temperature)
      finished |= tf.equal(next_id, eos_id)
      next_id = tf.expand_dims(next_id, axis=1)
      cache_flag = tf.py_func(sentence_cache.AddMultipleEntries,
                              [next_id, out], tf.int64)
      cache_flag.set_shape(tf.TensorShape([]))
      decoded_ids = tf.concat([decoded_ids, next_id], axis=1)
      return cache_flag, i + 1, finished, next_id, decoded_ids, cache

    def is_not_finished(cache_flag, i, finished, *_):
      return (i < decode_length) & tf.logical_not(tf.reduce_all(finished))

    decoded_ids = tf.zeros([batch_size, 0], dtype=tf.int64)
    finished = tf.fill([batch_size], False)
    next_id = tf.zeros([batch_size, 1], dtype=tf.int64)
    cache_flag, _, _, _, decoded_ids, _ = tf.while_loop(
        is_not_finished,
        inner_loop,
        [cache_flag, tf.constant(0), finished, next_id, decoded_ids, cache],
        shape_invariants=[
            tf.TensorShape([]),
            tf.TensorShape([]),
            tf.TensorShape([None]),
            tf.TensorShape([None, None]),
            tf.TensorShape([None, None]),
            nest.map_structure(beam_search.get_state_shape_invariants, cache),
        ])
    scores = None

  return {"outputs": decoded_ids + cache_flag, "scores": scores}
def fast_decode(encoder_output,
                encoder_decoder_attention_bias,
                symbols_to_logits_fn,
                hparams,
                decode_length,
                vocab_size,
                beam_size=1,
                top_beams=1,
                alpha=1.0,
                eos_id=beam_search.EOS_ID,
                batch_size=None):
  """Given encoder output and a symbols to logits function, does fast decoding.

  Implements both greedy and beam search decoding, uses beam search iff
  beam_size > 1, otherwise beam search related arguments are ignored.

  Args:
    encoder_output: Output from encoder.
    encoder_decoder_attention_bias: a bias tensor for use in encoder-decoder
      attention
    symbols_to_logits_fn: Incremental decoding; function mapping triple
      `(ids, step, cache)` to symbol logits.
    hparams: run hyperparameters
    decode_length: an integer. How many additional timesteps to decode.
    vocab_size: Output vocabulary size.
    beam_size: number of beams.
    top_beams: an integer. How many of the beams to return.
    alpha: Float that controls the length penalty. The larger the alpha,
      the stronger the preference for longer translations.
    eos_id: End-of-sequence symbol in beam search.
    batch_size: an integer scalar - must be passed if there is no input

  Returns:
    A dict of decoding results {
        "outputs": integer `Tensor` of decoded ids of shape
            [batch_size, <= decode_length] if top_beams == 1 or
            [batch_size, top_beams, <= decode_length] otherwise
        "scores": decoding log probs from the beam search,
            None if using greedy decoding (beam_size=1)
    }

  Raises:
    NotImplementedError: If beam size > 1 with partial targets.
  """
  if encoder_output is not None:
    batch_size = common_layers.shape_list(encoder_output)[0]

  key_channels = hparams.attention_key_channels or hparams.hidden_size
  value_channels = hparams.attention_value_channels or hparams.hidden_size
  num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers

  cache = {
      "layer_%d" % layer: {
          "k": tf.zeros([batch_size, 0, key_channels]),
          "v": tf.zeros([batch_size, 0, value_channels]),
      } for layer in range(num_layers)
  }

  if encoder_output is not None:
    cache["encoder_output"] = encoder_output
    cache["encoder_decoder_attention_bias"] = encoder_decoder_attention_bias

  if beam_size > 1:  # Beam Search
    initial_ids = tf.zeros([batch_size], dtype=tf.int32)
    decoded_ids, scores = beam_search.beam_search(
        symbols_to_logits_fn,
        initial_ids,
        beam_size,
        decode_length,
        vocab_size,
        alpha,
        states=cache,
        eos_id=eos_id,
        stop_early=(top_beams == 1))

    if top_beams == 1:
      decoded_ids = decoded_ids[:, 0, 1:]
    else:
      decoded_ids = decoded_ids[:, :top_beams, 1:]
  else:  # Greedy

    def inner_loop(i, finished, next_id, decoded_ids, cache):
      """One step of greedy decoding."""
      logits, cache = symbols_to_logits_fn(next_id, i, cache)
      temperature = (0.0 if hparams.sampling_method == "argmax" else
                     hparams.sampling_temp)
      next_id = common_layers.sample_with_temperature(logits, temperature)
      finished |= tf.equal(next_id, eos_id)
      next_id = tf.expand_dims(next_id, axis=1)
      decoded_ids = tf.concat([decoded_ids, next_id], axis=1)
      return i + 1, finished, next_id, decoded_ids, cache

    def is_not_finished(i, finished, *_):
      return (i < decode_length) & tf.logical_not(tf.reduce_all(finished))

    decoded_ids = tf.zeros([batch_size, 0], dtype=tf.int64)
    finished = tf.fill([batch_size], False)
    next_id = tf.zeros([batch_size, 1], dtype=tf.int64)
    _, _, _, decoded_ids, _ = tf.while_loop(
        is_not_finished,
        inner_loop,
        [tf.constant(0), finished, next_id, decoded_ids, cache],
        shape_invariants=[
            tf.TensorShape([]),
            tf.TensorShape([None]),
            tf.TensorShape([None, None]),
            tf.TensorShape([None, None]),
            nest.map_structure(beam_search.get_state_shape_invariants, cache),
        ])
    scores = None

  return {"outputs": decoded_ids, "scores": scores}
def _beam_decode_slow(self, features, decode_length, beam_size, top_beams,
                      alpha):
  """Slow version of Beam search decoding. Quadratic time in decode_length.

  Args:
    features: a map of string to `Tensor`
    decode_length: an integer. How many additional timesteps to decode.
    beam_size: number of beams.
    top_beams: an integer. How many of the beams to return.
    alpha: Float that controls the length penalty. The larger the alpha,
      the stronger the preference for longer translations.

  Returns:
    samples: an integer `Tensor`. Top samples from the beam search
  """
  beam_size = 1
  if use_bottom_up_features:
    features["inputs"] = features["bottom_up_features"]
  batch_size = common_layers.shape_list(features["inputs"])[0]

  def symbols_to_logits_fn(ids):
    """Go from ids to logits."""
    ids = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
    ids = tf.pad(ids[:, 1:], [[0, 0], [0, 1], [0, 0], [0, 0]])
    if "partial_targets" in features:
      pt = features["partial_targets"]
      pt_length = common_layers.shape_list(pt)[1]
      pt = tf.tile(pt, [1, beam_size])
      pt = tf.reshape(pt, [batch_size * beam_size, pt_length, 1, 1])
      ids = tf.concat([pt, ids], axis=1)

    features["targets"] = ids
    self._coverage = None
    logits, _ = self(features)  # pylint: disable=not-callable
    # now self._coverage is a coverage tensor for the first datashard.
    # it has shape [batch_size] and contains floats between 0 and
    # source_length.
    if self._problem_hparams:
      modality = self._problem_hparams.target_modality
      if modality.top_is_pointwise:
        return tf.squeeze(logits, axis=[1, 2, 3])
    # -1 due to the pad above.
    current_output_position = common_layers.shape_list(ids)[1] - 1
    logits = logits[:, current_output_position, :, :]
    return tf.squeeze(logits, axis=[1, 2])

  initial_ids = tf.zeros([batch_size], dtype=tf.int32)

  if self.has_input:
    inputs_old = features["inputs"]
    features["inputs"] = tf.expand_dims(features["inputs"], 1)
    if len(features["inputs"].shape) < 5:
      features["inputs"] = tf.expand_dims(features["inputs"], 4)
    # Expand the inputs into the beam size.
    features["inputs"] = tf.tile(features["inputs"], [1, beam_size, 1, 1, 1])
    s = common_layers.shape_list(features["inputs"])
    features["inputs"] = tf.reshape(features["inputs"],
                                    [s[0] * s[1], s[2], s[3], s[4]])
    features["bottom_up_features"] = features["inputs"]

  target_modality = self._problem_hparams.target_modality
  vocab_size = target_modality.top_dimensionality
  # Setting decode length to input length + decode_length
  decode_length = tf.constant(decode_length)
  if "partial_targets" not in features:
    decode_length += common_layers.shape_list(features["inputs"])[1]
  ids, scores = beam_search.beam_search(symbols_to_logits_fn, initial_ids,
                                        beam_size, decode_length, vocab_size,
                                        alpha, stop_early=(top_beams == 1))

  # Set inputs back to the unexpanded inputs so as not to confuse the
  # Estimator!
  if self.has_input:
    features["inputs"] = inputs_old
    features["bottom_up_features"] = inputs_old

  # Return `top_beams` decodings (also remove initial id from the beam search)
  # TODO(lukaszkaiser): make it work multi-problem.
  if top_beams == 1:
    samples = ids[:, 0, 1:]
  else:
    samples = ids[:, :top_beams, 1:]

  return {"outputs": samples, "scores": scores}
def _fast_decode(self,
                 features,
                 decode_length,
                 beam_size=1,
                 top_beams=1,
                 alpha=1.0):
  """Fast decoding.

  Implements both greedy and beam search decoding, uses beam search iff
  beam_size > 1, otherwise beam search related arguments are ignored.

  Args:
    features: a map of string to model features.
    decode_length: an integer. How many additional timesteps to decode.
    beam_size: number of beams.
    top_beams: an integer. How many of the beams to return.
    alpha: Float that controls the length penalty. The larger the alpha,
      the stronger the preference for longer translations.

  Returns:
    samples: an integer `Tensor`. Top samples from the beam search

  Raises:
    NotImplementedError: If there are multiple data shards.
  """
  if self._num_datashards != 1:
    raise NotImplementedError("Fast decoding only supports a single shard.")
  dp = self._data_parallelism
  hparams = self._hparams

  inputs = features["inputs"]
  batch_size = common_layers.shape_list(inputs)[0]
  target_modality = self._problem_hparams.target_modality
  if target_modality.is_class_modality:
    decode_length = 1
  else:
    decode_length = common_layers.shape_list(inputs)[1] + decode_length

  # TODO(llion): Clean up this reshaping logic.
  inputs = tf.expand_dims(inputs, axis=1)
  if len(inputs.shape) < 5:
    inputs = tf.expand_dims(inputs, axis=4)
  s = common_layers.shape_list(inputs)
  inputs = tf.reshape(inputs, [s[0] * s[1], s[2], s[3], s[4]])
  # _shard_features called to ensure that the variable names match
  inputs = self._shard_features({"inputs": inputs})["inputs"]
  input_modality = self._problem_hparams.input_modality["inputs"]
  with tf.variable_scope(input_modality.name):
    inputs = input_modality.bottom_sharded(inputs, dp)
  with tf.variable_scope("body"):
    encoder_output, encoder_decoder_attention_bias = dp(
        self.encode,
        inputs,
        features["target_space_id"],
        hparams,
        features=features)
  encoder_output = encoder_output[0]
  encoder_decoder_attention_bias = encoder_decoder_attention_bias[0]

  if hparams.pos == "timing":
    timing_signal = common_attention.get_timing_signal_1d(
        decode_length + 1, hparams.hidden_size)

  def preprocess_targets(targets, i):
    """Performs preprocessing steps on the targets to prepare for the decoder.

    This includes:
      - Embedding the ids.
      - Flattening to 3D tensor.
      - Optionally adding timing signals.

    Args:
      targets: inputs ids to the decoder. [batch_size, 1]
      i: scalar, Step number of the decoding loop.

    Returns:
      Processed targets [batch_size, 1, hidden_dim]
    """
    # _shard_features called to ensure that the variable names match
    targets = self._shard_features({"targets": targets})["targets"]
    with tf.variable_scope(target_modality.name):
      targets = target_modality.targets_bottom_sharded(targets, dp)[0]
    targets = common_layers.flatten4d3d(targets)

    # TODO(llion): Explain! Is this even needed?
    targets = tf.cond(
        tf.equal(i, 0), lambda: tf.zeros_like(targets), lambda: targets)

    if hparams.pos == "timing":
      targets += timing_signal[:, i:i + 1]
    return targets

  decoder_self_attention_bias = (
      common_attention.attention_bias_lower_triangle(decode_length))
  if hparams.proximity_bias:
    decoder_self_attention_bias += common_attention.attention_bias_proximal(
        decode_length)

  def symbols_to_logits_fn(ids, i, cache):
    """Go from ids to logits for next symbol."""
    ids = ids[:, -1:]
    targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
    targets = preprocess_targets(targets, i)

    bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]

    with tf.variable_scope("body"):
      body_outputs = dp(
          self.decode,
          targets,
          cache["encoder_output"],
          cache["encoder_decoder_attention_bias"],
          bias,
          hparams,
          cache,
          nonpadding=_features_to_nonpadding(features, "targets"))

    with tf.variable_scope(target_modality.name):
      logits = target_modality.top_sharded(body_outputs, None, dp)[0]

    return tf.squeeze(logits, axis=[1, 2, 3]), cache

  key_channels = hparams.attention_key_channels or hparams.hidden_size
  value_channels = hparams.attention_value_channels or hparams.hidden_size
  num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers

  cache = {
      "layer_%d" % layer: {
          "k": tf.zeros([batch_size, 0, key_channels]),
          "v": tf.zeros([batch_size, 0, value_channels]),
      } for layer in range(num_layers)
  }

  # Set 2nd dim to None since it's not invariant in the tf.while_loop
  # Note: Tensor.set_shape() does not work here since it merges shape info.
  # TODO(llion); Find a more robust solution.
  # pylint: disable=protected-access
  if not context.in_eager_mode():
    for layer in cache:
      cache[layer]["k"]._shape = tf.TensorShape([None, None, key_channels])
      cache[layer]["v"]._shape = tf.TensorShape([None, None, value_channels])
  # pylint: enable=protected-access
  cache["encoder_output"] = encoder_output
  cache["encoder_decoder_attention_bias"] = encoder_decoder_attention_bias

  if beam_size > 1:  # Beam Search
    target_modality = (
        self._hparams.problems[self._problem_idx].target_modality)
    vocab_size = target_modality.top_dimensionality
    initial_ids = tf.zeros([batch_size], dtype=tf.int32)
    decoded_ids, scores = beam_search.beam_search(
        symbols_to_logits_fn,
        initial_ids,
        beam_size,
        decode_length,
        vocab_size,
        alpha,
        states=cache,
        stop_early=(top_beams == 1))

    if top_beams == 1:
      decoded_ids = decoded_ids[:, 0, 1:]
    else:
      decoded_ids = decoded_ids[:, :top_beams, 1:]
  else:  # Greedy

    def inner_loop(i, next_id, decoded_ids, cache):
      logits, cache = symbols_to_logits_fn(next_id, i, cache)
      temperature = (0.0 if hparams.sampling_method == "argmax" else
                     hparams.sampling_temp)
      next_id = tf.expand_dims(
          common_layers.sample_with_temperature(logits, temperature), axis=1)
      decoded_ids = tf.concat([decoded_ids, next_id], axis=1)
      return i + 1, next_id, decoded_ids, cache

    decoded_ids = tf.zeros([batch_size, 0], dtype=tf.int64)
    scores = None
    next_id = tf.zeros([batch_size, 1], dtype=tf.int64)
    _, _, decoded_ids, _ = tf.while_loop(
        # TODO(llion): Early stopping.
        lambda i, *_: tf.less(i, decode_length),
        inner_loop,
        [tf.constant(0), next_id, decoded_ids, cache],
        shape_invariants=[
            tf.TensorShape([]),
            tf.TensorShape([None, None]),
            tf.TensorShape([None, None]),
            nest.map_structure(lambda t: tf.TensorShape(t.shape), cache),
        ])

  return decoded_ids, scores
def transformer_beam_search(self, abstr_outputs, abstr_bias, emb_kword,
                            proj_w, proj_b, hist_vector=None):
  # Use Beam Search in evaluation stage
  # Update [a, b, c] to [a, a, a, b, b, b, c, c, c] if beam_search_size == 3
  encoder_beam_outputs = tf.concat(
      [tf.tile(tf.expand_dims(abstr_outputs[o, :, :], axis=0),
               [self.model_config.beam_search_size, 1, 1])
       for o in range(self.model_config.batch_size)], axis=0)

  encoder_attn_beam_bias = tf.concat(
      [tf.tile(tf.expand_dims(abstr_bias[o, :, :, :], axis=0),
               [self.model_config.beam_search_size, 1, 1, 1])
       for o in range(self.model_config.batch_size)], axis=0)

  hist_beam_vector = tf.concat(
      [tf.tile(tf.expand_dims(hist_vector[o, :, :], axis=0),
               [self.model_config.beam_search_size, 1, 1])
       for o in range(self.model_config.batch_size)], axis=0)

  if self.model_config.subword_vocab_size:
    go_id = self.voc_kword.encode(constant.SYMBOL_GO)[0]
  else:
    go_id = self.voc_kword.encode(constant.SYMBOL_GO)
  batch_go = tf.expand_dims(
      tf.tile(tf.expand_dims(self.embedding_fn(go_id, emb_kword), axis=0),
              [self.model_config.batch_size, 1]), axis=1)
  batch_go_beam = tf.concat(
      [tf.tile(tf.expand_dims(batch_go[o, :, :], axis=0),
               [self.model_config.beam_search_size, 1, 1])
       for o in range(self.model_config.batch_size)], axis=0)

  def symbol_to_logits_fn(ids):
    cur_ids = ids[:, 1:]
    embs = tf.nn.embedding_lookup(emb_kword, cur_ids)
    embs = tf.concat([batch_go_beam, embs], axis=1)

    final_outputs = self.decode_inputs_to_outputs(
        embs, encoder_beam_outputs, encoder_attn_beam_bias,
        hist_vector=hist_beam_vector)
    return self.output_to_logit(final_outputs[:, -1, :], proj_w, proj_b)

  beam_ids, beam_score = beam_search.beam_search(
      symbol_to_logits_fn,
      tf.zeros([self.model_config.batch_size], tf.int32),
      self.model_config.beam_search_size,
      self.model_config.max_kword_len,
      self.voc_kword.vocab_size(), 0.6)
  top_beam_ids = beam_ids[:, 0, 1:]
  top_beam_ids = tf.pad(
      top_beam_ids,
      [[0, 0],
       [0, self.model_config.max_kword_len - tf.shape(top_beam_ids)[1]]])
  decoder_target_list = [
      tf.squeeze(d, 1)
      for d in tf.split(top_beam_ids, self.model_config.max_kword_len, axis=1)
  ]
  decoder_score = -beam_score[:, 0] / tf.to_float(tf.shape(top_beam_ids)[1])
  return decoder_score, top_beam_ids
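# Sketch (NumPy only) of the expansion pattern used above to give every beam
# its own copy of the per-sentence encoder state: [a, b, c] becomes
# [a, a, a, b, b, b, c, c, c] for beam_search_size == 3.  A single tile plus
# reshape can achieve the same result without the Python-level loop over the
# batch that transformer_beam_search uses; the array below is made up.
import numpy as np

beam_size = 3
enc = np.array([["a"], ["b"], ["c"]])             # [batch, ...]
tiled = np.reshape(
    np.tile(enc[:, None, :], [1, beam_size, 1]),  # [batch, beam, ...]
    [-1, enc.shape[1]])                           # [batch * beam, ...]
print(tiled.ravel().tolist())  # ['a', 'a', 'a', 'b', 'b', 'b', 'c', 'c', 'c']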
def decode_beam_search(self,
                       start_ids,
                       eos_id,
                       pad_id,
                       enc_output,
                       enc_mask,
                       scope="model"):
  batch_size = tf.shape(start_ids)[0]
  cache = {  # pylint: disable=g-complex-comprehension
      "layer_%d" % layer: {
          "uniform_avg": tf.zeros([batch_size, 1, self.model_dimension]),
      } for layer in range(self.num_layers)
  }
  cache["logits"] = tf.zeros([batch_size, 0, self.vocabulary_size])
  pos_indices = tf.range(self.max_dec_time_step, dtype=tf.int32)
  pos_indices = tf.reshape(pos_indices, [1, -1])
  pos_values = self.positional_embedding(pos_indices)

  def beam_search_tile(output, tile_pattern, final_shape):
    x = tf.tile(output, tile_pattern)
    x = tf.reshape(x, final_shape)
    return x

  enc_output_feature_dim = enc_output.get_shape().as_list()[2]
  enc_output = beam_search_tile(
      enc_output, [1, self.beam_size, 1],
      [batch_size * self.beam_size, -1, enc_output_feature_dim])
  enc_mask = beam_search_tile(enc_mask, [1, self.beam_size],
                              [batch_size * self.beam_size, -1])

  def symbols_to_logits_fn(ids, step, cache):
    """Looks up ids to logits."""
    logging.info("Running symbols to logits. ids=%s, step=%s, cache=%s", ids,
                 step, cache)
    curr_id = ids[:, -1:]
    with tf.name_scope(scope):
      curr_embed = self.embedding(curr_id)
      input_mask = tf.ones(tf.shape(curr_embed)[:-1], dtype=tf.float32)
      if self.embedding_size != self.model_dimension:
        curr_embed = self.input_bottleneck(curr_embed, input_mask)
      inputs = self.qact(
          self.ln(curr_embed + pos_values[:, step:step + 1, :]))
      layer_out = self.transformer_uniform_attn_decoder(
          inputs,
          input_mask,
          enc_output,
          enc_mask,
          step=step + 1,
          cache=cache)
      next_logits, _ = self.model_outputs(layer_out)
      cache["logits"] = tf.concat([cache["logits"], next_logits], axis=1)
    return next_logits, cache

  self.finished_seq, self.finished_scores, states = beam_search.beam_search(
      symbols_to_logits_fn,
      initial_ids=start_ids,
      beam_size=self.beam_size,
      decode_length=self.max_dec_time_step,
      vocab_size=self.vocabulary_size,
      alpha=0.6,
      eos_id=eos_id,
      states=cache)
  beam_ids = self.finished_seq[:, 0, 1:]
  beam_ids = tf.pad(
      beam_ids,
      [[0, 0], [0, self.max_dec_time_step - tf.shape(beam_ids)[1]]],
      constant_values=pad_id)
  logits = states["logits"][:, 0, :, :]
  logits = tf.pad(
      logits,
      [[0, 0], [0, self.max_dec_time_step - tf.shape(logits)[1]], [0, 0]],
      constant_values=self.parameters.invalid_logit)
  return logits, beam_ids
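# Sketch (NumPy only) of the post-beam padding done above: beam search may
# finish before max_dec_time_step, so the winning beam is right-padded with
# pad_id to a fixed length.  The slice [:, 0, 1:] picks the top beam and drops
# the initial GO/start id, mirroring `self.finished_seq[:, 0, 1:]`; the
# numbers below are made up.
import numpy as np

max_dec_time_step, pad_id = 6, 0
finished_seq = np.array([[[9, 4, 7, 2]]])  # [batch, beam, <= max + 1]
beam_ids = finished_seq[:, 0, 1:]          # top beam, GO removed
beam_ids = np.pad(
    beam_ids, [(0, 0), (0, max_dec_time_step - beam_ids.shape[1])],
    constant_values=pad_id)
print(beam_ids)  # [[4 7 2 0 0 0]]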
def fast_decode(symbols_to_logits_fn,
                hparams,
                decode_length,
                vocab_size,
                beam_size=1,
                top_beams=1,
                alpha=1.0,
                eos_id=beam_search.EOS_ID,
                batch_size=None,
                force_decode_length=False,
                cache=None):
  """Given a symbols to logits function, does fast decoding.

  Implements both greedy and beam search decoding, uses beam search iff
  beam_size > 1, otherwise beam search related arguments are ignored.

  Args:
    symbols_to_logits_fn: Incremental decoding; function mapping triple `(ids,
      step, cache)` to symbol logits.
    hparams: run hyperparameters
    decode_length: an integer. How many additional timesteps to decode.
    vocab_size: Output vocabulary size.
    beam_size: number of beams.
    top_beams: an integer. How many of the beams to return.
    alpha: Float that controls the length penalty. Larger the alpha, stronger
      the preference for longer translations.
    eos_id: End-of-sequence symbol in beam search.
    batch_size: an integer scalar - must be passed if there is no input
    force_decode_length: bool, whether to force the full decode length, or if
      False, stop when all beams hit eos_id.
    cache: cache dictionary for the decoding loop; also used to return the
      decoding results.

  Returns:
    A dict of decoding results {
        "outputs": integer `Tensor` of decoded ids of shape
            [batch_size, <= decode_length] if top_beams == 1 or
            [batch_size, top_beams, <= decode_length] otherwise
        "scores": decoding log probs from the beam search,
            None if using greedy decoding (beam_size=1)
    }
  """
  if beam_size > 1:  # Beam Search
    initial_ids = tf.zeros([batch_size], dtype=tf.int32)
    decoded_ids, scores, cache = beam_search.beam_search(
        symbols_to_logits_fn,
        initial_ids,
        beam_size,
        decode_length,
        vocab_size,
        alpha,
        states=cache,
        eos_id=eos_id,
        stop_early=(top_beams == 1))

    if top_beams == 1:
      decoded_ids = decoded_ids[:, 0, 1:]
      scores = scores[:, 0]
    else:
      decoded_ids = decoded_ids[:, :top_beams, 1:]
      scores = scores[:, :top_beams]
  else:  # Greedy

    def inner_loop(i, hit_eos, next_id, decoded_ids, cache, log_prob):
      """One step of greedy decoding."""
      logits, cache = symbols_to_logits_fn(next_id, i, cache)
      log_probs = common_layers.log_prob_from_logits(logits)
      temperature = (0.0 if hparams.sampling_method == "argmax" else
                     hparams.sampling_temp)
      next_id = common_layers.sample_with_temperature(logits, temperature)
      hit_eos |= tf.equal(next_id, eos_id)

      log_prob_indices = tf.stack(
          [tf.range(tf.to_int64(batch_size)), next_id], axis=1)
      log_prob += tf.gather_nd(log_probs, log_prob_indices)

      next_id = tf.expand_dims(next_id, axis=1)
      decoded_ids = tf.concat([decoded_ids, next_id], axis=1)
      return i + 1, hit_eos, next_id, decoded_ids, cache, log_prob

    def is_not_finished(i, hit_eos, *_):
      finished = i >= decode_length
      if not force_decode_length:
        finished |= tf.reduce_all(hit_eos)
      return tf.logical_not(finished)

    decoded_ids = tf.zeros([batch_size, 0], dtype=tf.int64)
    hit_eos = tf.fill([batch_size], False)
    next_id = tf.zeros([batch_size, 1], dtype=tf.int64)
    initial_log_prob = tf.zeros([batch_size], dtype=tf.float32)
    _, _, _, decoded_ids, cache, log_prob = tf.while_loop(
        is_not_finished,
        inner_loop, [
            tf.constant(0), hit_eos, next_id, decoded_ids, cache,
            initial_log_prob
        ],
        shape_invariants=[
            tf.TensorShape([]),
            tf.TensorShape([None]),
            tf.TensorShape([None, None]),
            tf.TensorShape([None, None]),
            tf.contrib.framework.nest.map_structure(
                beam_search.get_state_shape_invariants, cache),
            tf.TensorShape([None]),
        ])
    scores = log_prob

  cache["outputs"] = decoded_ids
  cache["scores"] = scores
  return cache
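# Sketch (plain Python/NumPy, dummy model) of the greedy branch above: pick
# the argmax id step by step, accumulate the chosen tokens' log probs, and
# stop once every sequence in the batch has emitted eos_id.  `greedy_decode`
# and `next_logits` are illustrative stand-ins for the TF while_loop and
# symbols_to_logits_fn.
import numpy as np

def greedy_decode(next_logits, batch_size, decode_length, eos_id):
  decoded = [[] for _ in range(batch_size)]
  log_prob = np.zeros(batch_size)
  hit_eos = np.zeros(batch_size, dtype=bool)
  for step in range(decode_length):
    logits = next_logits(decoded, step)  # [batch, vocab]
    log_probs = logits - np.log(np.sum(np.exp(logits), -1, keepdims=True))
    next_id = np.argmax(logits, axis=-1)
    log_prob += log_probs[np.arange(batch_size), next_id]
    hit_eos |= next_id == eos_id
    for b, i in enumerate(next_id):
      decoded[b].append(int(i))
    if hit_eos.all():  # early exit, as when force_decode_length=False
      break
  return decoded, log_prob

rng = np.random.RandomState(0)
ids, lp = greedy_decode(
    lambda d, s: rng.randn(2, 5), batch_size=2, decode_length=4, eos_id=1)
print(ids, lp)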
def fast_decode(
    encoder_output,
    encoder_decoder_attention_bias,
    symbols_to_logits_fn,
    hparams,
    decode_length,
    vocab_size,
    init_cache_fn=_init_transformer_cache,
    beam_size=1,
    top_beams=1,
    alpha=1.0,
    sos_id=0,
    eos_id=beam_search.EOS_ID,
    batch_size=None,
    force_decode_length=False,
    scope_prefix='body/',
    sampling_temperature=0.0,
    top_k=-1,
    cache=None,
):
  """Given encoder output and a symbols to logits function, does fast decoding.

  Implements both greedy and beam search decoding, uses beam search iff
  beam_size > 1, otherwise beam search related arguments are ignored.

  Args:
    encoder_output: Output from encoder.
    encoder_decoder_attention_bias: a bias tensor for use in encoder-decoder
      attention
    symbols_to_logits_fn: Incremental decoding; function mapping triple `(ids,
      step, cache)` to symbol logits.
    hparams: run hyperparameters
    decode_length: an integer. How many additional timesteps to decode.
    vocab_size: Output vocabulary size.
    init_cache_fn: Function that returns the initial cache dict.
    beam_size: number of beams.
    top_beams: an integer. How many of the beams to return.
    alpha: Float that controls the length penalty. Larger the alpha, stronger
      the preference for longer translations.
    sos_id: Start-of-sequence symbol in beam search.
    eos_id: End-of-sequence symbol in beam search.
    batch_size: an integer scalar - must be passed if there is no input
    force_decode_length: bool, whether to force the full decode length, or if
      False, stop when all beams hit eos_id.
    scope_prefix: str, prefix for decoder layer variable scopes.
    sampling_temperature: scalar, temperature with which to sample.
    top_k: scalar, sample only top k.
    cache: cache dictionary for additional predictions.

  Returns:
    A dict of decoding results {
        "outputs": integer `Tensor` of decoded ids of shape
            [batch_size, <= decode_length] if top_beams == 1 or
            [batch_size, top_beams, <= decode_length] otherwise
        "scores": decoding log probs from the beam search,
            None if using greedy decoding (beam_size=1)
    }
  """
  if encoder_output is not None:
    batch_size = common_layers.shape_list(encoder_output)[0]

  cache = init_cache_fn(
      cache=cache,
      hparams=hparams,
      batch_size=batch_size,
      attention_init_length=0,
      encoder_output=encoder_output,
      encoder_decoder_attention_bias=encoder_decoder_attention_bias,
      scope_prefix=scope_prefix,
  )

  if beam_size > 1:  # Beam Search
    initial_ids = sos_id * tf.ones([batch_size], dtype=tf.int32)
    decoded_ids, scores, cache = beam_search.beam_search(
        symbols_to_logits_fn,
        initial_ids,
        beam_size,
        decode_length,
        vocab_size,
        alpha,
        states=cache,
        eos_id=eos_id,
        stop_early=(top_beams == 1),
    )

    if top_beams == 1:
      decoded_ids = decoded_ids[:, 0, 1:]
      scores = scores[:, 0]
    else:
      decoded_ids = decoded_ids[:, :top_beams, 1:]
      scores = scores[:, :top_beams]
    # The beam search path produces no tags; return None for them so the
    # shared return below does not reference an undefined name.
    decoded_ids_tag = None
  else:

    def inner_loop(
        i,
        hit_eos,
        next_id,
        next_id_tag,
        decoded_ids,
        decoded_ids_tag,
        cache,
        log_prob,
    ):
      """One step of greedy decoding."""
      logits, logits_tag, cache = symbols_to_logits_fn(
          next_id, next_id_tag, i, cache)
      log_probs = common_layers.log_prob_from_logits(logits)
      temperature = sampling_temperature
      if hparams.sampling_method == 'random_per_example':
        next_id = common_layers.sample_temperature_per_example(
            logits, temperature, top_k)
      else:
        if hparams.sampling_method == 'argmax':
          temperature = 0.0
        next_id = common_layers.sample_with_temperature(
            logits, temperature, top_k)

      if hparams.sampling_method == 'random_per_example':
        next_id_tag = common_layers.sample_temperature_per_example(
            logits_tag, temperature, top_k)
      else:
        if hparams.sampling_method == 'argmax':
          temperature = 0.0
        next_id_tag = common_layers.sample_with_temperature(
            logits_tag, temperature, top_k)

      log_prob_indices = tf.stack(
          [tf.range(tf.to_int64(batch_size)), next_id], axis=1)
      log_prob += tf.gather_nd(
          log_probs, log_prob_indices) * (1 - tf.to_float(hit_eos))
      hit_eos |= tf.equal(next_id, eos_id)

      next_id = tf.expand_dims(next_id, axis=1)
      decoded_ids = tf.concat([decoded_ids, next_id], axis=1)

      next_id_tag = tf.expand_dims(next_id_tag, axis=1)
      decoded_ids_tag = tf.concat([decoded_ids_tag, next_id_tag], axis=1)

      return (
          i + 1,
          hit_eos,
          next_id,
          next_id_tag,
          decoded_ids,
          decoded_ids_tag,
          cache,
          log_prob,
      )

    def is_not_finished(i, hit_eos, *_):
      finished = i >= decode_length
      if not force_decode_length:
        finished |= tf.reduce_all(hit_eos)
      return tf.logical_not(finished)

    decoded_ids = tf.zeros([batch_size, 0], dtype=tf.int64)
    decoded_ids_tag = tf.zeros([batch_size, 0], dtype=tf.int64)
    hit_eos = tf.fill([batch_size], False)
    next_id = sos_id * tf.ones([batch_size, 1], dtype=tf.int64)
    next_id_tag = sos_id * tf.ones([batch_size, 1], dtype=tf.int64)
    initial_log_prob = tf.zeros([batch_size], dtype=tf.float32)
    _, _, _, _, decoded_ids, decoded_ids_tag, cache, log_prob = tf.while_loop(
        is_not_finished,
        inner_loop,
        [
            tf.constant(0),
            hit_eos,
            next_id,
            next_id_tag,
            decoded_ids,
            decoded_ids_tag,
            cache,
            initial_log_prob,
        ],
        shape_invariants=[
            tf.TensorShape([]),
            tf.TensorShape([None]),
            tf.TensorShape([None, None]),
            tf.TensorShape([None, None]),
            tf.TensorShape([None, None]),
            tf.TensorShape([None, None]),
            nest.map_structure(beam_search.get_state_shape_invariants, cache),
            tf.TensorShape([None]),
        ],
    )
    scores = log_prob

  return {
      'outputs': decoded_ids,
      'outputs_tag': decoded_ids_tag,
      'scores': scores,
      'cache': cache,
  }
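# Sketch (NumPy only) of the sampling behavior assumed above for
# common_layers.sample_with_temperature: temperature 0.0 degenerates to
# argmax; otherwise logits are divided by the temperature, optionally
# restricted to the top_k largest entries, and sampled from.  This NumPy
# re-implementation is illustrative, not the library function itself.
import numpy as np

def sample_with_temperature(logits, temperature, top_k=-1, rng=np.random):
  if temperature == 0.0:
    return int(np.argmax(logits))
  if top_k > 0:
    cutoff = np.sort(logits)[-top_k]
    logits = np.where(logits < cutoff, -1e9, logits)  # mask non-top-k entries
  scaled = logits / temperature
  probs = np.exp(scaled - scaled.max())
  probs /= probs.sum()
  return int(rng.choice(len(logits), p=probs))

logits = np.array([2.0, 1.0, 0.1])
print(sample_with_temperature(logits, 0.0))           # always 0
print(sample_with_temperature(logits, 1.0, top_k=2))  # 0 or 1, never 2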
def _fast_decode(self,
                 features,
                 decode_length,
                 last_position_only=True,
                 beam_size=1,
                 top_beams=1,
                 alpha=1.0):
  """Fast decoding.

  Implements both greedy and beam search decoding, uses beam search iff
  beam_size > 1, otherwise beam search related arguments are ignored.

  Args:
    features: a map of string to model features.
    decode_length: an integer. How many additional timesteps to decode.
    last_position_only: MUST be true for fast decoding!
    beam_size: number of beams.
    top_beams: an integer. How many of the beams to return.
    alpha: Float that controls the length penalty. Larger the alpha, stronger
      the preference for longer translations.

  Returns:
    samples: an integer `Tensor`. Top samples from the beam search

  Raises:
    ValueError: If last_position_only is False
    NotImplementedError: If there are multiple data shards.
  """
  if not last_position_only:
    raise ValueError("Fast decoding only deals with the last positions!")
  if self._num_datashards != 1:
    raise NotImplementedError("Fast decoding only supports a single shard.")
  dp = self._data_parallelism
  hparams = self._hparams

  inputs = features["inputs"]
  batch_size = tf.shape(inputs)[0]
  target_modality = self._problem_hparams.target_modality
  if t2t_model.is_class_modality(target_modality):
    decode_length = 1
  else:
    decode_length = tf.shape(inputs)[1] + decode_length

  # TODO(llion): Clean up this reshaping logic.
  inputs = tf.expand_dims(inputs, axis=1)
  if len(inputs.shape) < 5:
    inputs = tf.expand_dims(inputs, axis=4)
  s = tf.shape(inputs)
  inputs = tf.reshape(inputs, [s[0] * s[1], s[2], s[3], s[4]])
  # _shard_features called to ensure that the variable names match
  inputs = self._shard_features({"inputs": inputs})["inputs"]
  input_modality = self._problem_hparams.input_modality["inputs"]
  with tf.variable_scope(input_modality.name):
    inputs = input_modality.bottom_sharded(inputs, dp)
  with tf.variable_scope("body"):
    encoder_output, encoder_decoder_attention_bias = dp(
        self.encode, inputs, features["target_space_id"], hparams)
  encoder_output = encoder_output[0]
  encoder_decoder_attention_bias = encoder_decoder_attention_bias[0]

  if hparams.pos == "timing":
    timing_signal = common_attention.get_timing_signal_1d(
        decode_length + 1, hparams.hidden_size)

  def preprocess_targets(targets, i):
    """Performs preprocessing steps on the targets to prepare for the decoder.

    This includes:
      - Embedding the ids.
      - Flattening to 3D tensor.
      - Optionally adding timing signals.

    Args:
      targets: inputs ids to the decoder. [batch_size, 1]
      i: scalar, Step number of the decoding loop.

    Returns:
      Processed targets [batch_size, 1, hidden_dim]
    """
    # _shard_features called to ensure that the variable names match
    targets = self._shard_features({"targets": targets})["targets"]
    with tf.variable_scope(target_modality.name):
      targets = target_modality.targets_bottom_sharded(targets, dp)[0]
    targets = common_layers.flatten4d3d(targets)

    # TODO(llion): Explain! Is this even needed?
    targets = tf.cond(
        tf.equal(i, 0), lambda: tf.zeros_like(targets), lambda: targets)

    if hparams.pos == "timing":
      targets += timing_signal[:, i:i + 1]
    return targets

  decoder_self_attention_bias = (
      common_attention.attention_bias_lower_triangle(decode_length))
  if hparams.proximity_bias:
    decoder_self_attention_bias += common_attention.attention_bias_proximal(
        decode_length)

  def symbols_to_logits_fn(ids, i, cache):
    """Go from ids to logits for next symbol."""
    ids = ids[:, -1:]
    targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
    targets = preprocess_targets(targets, i)

    bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]

    with tf.variable_scope("body"):
      body_outputs = dp(self.decode, targets, cache["encoder_output"],
                        cache["encoder_decoder_attention_bias"], bias,
                        hparams, cache)

    with tf.variable_scope(target_modality.name):
      logits = target_modality.top_sharded(body_outputs, None, dp)[0]

    return tf.squeeze(logits, axis=[1, 2, 3]), cache

  key_channels = hparams.attention_key_channels or hparams.hidden_size
  value_channels = hparams.attention_value_channels or hparams.hidden_size
  num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers

  cache = {
      "layer_%d" % layer: {
          "k": tf.zeros([batch_size, 0, key_channels]),
          "v": tf.zeros([batch_size, 0, value_channels]),
      } for layer in range(num_layers)
  }

  # Set 2nd dim to None since it's not invariant in the tf.while_loop
  # Note: Tensor.set_shape() does not work here since it merges shape info.
  # TODO(llion); Find a more robust solution.
  # pylint: disable=protected-access
  for layer in cache:
    cache[layer]["k"]._shape = tf.TensorShape([None, None, key_channels])
    cache[layer]["v"]._shape = tf.TensorShape([None, None, value_channels])
  # pylint: enable=protected-access
  cache["encoder_output"] = encoder_output
  cache["encoder_decoder_attention_bias"] = encoder_decoder_attention_bias

  if beam_size > 1:  # Beam Search
    target_modality = (
        self._hparams.problems[self._problem_idx].target_modality)
    vocab_size = target_modality.top_dimensionality
    initial_ids = tf.zeros([batch_size], dtype=tf.int32)
    decoded_ids, _ = beam_search.beam_search(
        symbols_to_logits_fn,
        initial_ids,
        beam_size,
        decode_length,
        vocab_size,
        alpha,
        states=cache)

    if top_beams == 1:
      decoded_ids = decoded_ids[:, 0, 1:]
    else:
      decoded_ids = decoded_ids[:, :top_beams, 1:]
  else:  # Greedy

    def inner_loop(i, next_id, decoded_ids, cache):
      logits, cache = symbols_to_logits_fn(next_id, i, cache)
      next_id = tf.expand_dims(tf.argmax(logits, axis=-1), axis=1)
      decoded_ids = tf.concat([decoded_ids, next_id], axis=1)
      return i + 1, next_id, decoded_ids, cache

    decoded_ids = tf.zeros([batch_size, 0], dtype=tf.int64)
    next_id = tf.zeros([batch_size, 1], dtype=tf.int64)
    _, _, decoded_ids, _ = tf.while_loop(
        # TODO(llion): Early stopping.
        lambda i, *_: tf.less(i, decode_length),
        inner_loop,
        [tf.constant(0), next_id, decoded_ids, cache],
        shape_invariants=[
            tf.TensorShape([]),
            tf.TensorShape([None, None]),
            tf.TensorShape([None, None]),
            nest.map_structure(lambda t: tf.TensorShape(t.shape), cache),
        ])

  return decoded_ids
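# Sketch (NumPy only) of why the per-layer cache above starts with a
# zero-length time dimension ([batch, 0, channels]) and why its static shape
# must be relaxed for tf.while_loop: each decode step appends one new
# key/value row, so the time dimension grows by one per iteration.  Shapes
# below are illustrative.
import numpy as np

batch, channels, steps = 2, 4, 3
k_cache = np.zeros([batch, 0, channels])
for step in range(steps):
  new_k = np.random.randn(batch, 1, channels)  # this step's projected keys
  k_cache = np.concatenate([k_cache, new_k], axis=1)
print(k_cache.shape)  # (2, 3, 4): one cached row per decoded position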
def fast_decode(wav_encoder_output,
                txt_encoder_output,
                wav_enc_dec_attention_bias,
                txt_enc_dec_attention_bias,
                symbols_to_logits_fn,
                hparams,
                decode_length,
                vocab_size,
                beam_size=1,
                top_beams=1,
                alpha=1.0,
                eos_id=beam_search.EOS_ID,
                batch_size=None,
                force_decode_length=False):
  """Implements both greedy and beam search decoding.

  Uses beam search iff beam_size > 1, otherwise beam search related arguments
  are ignored.

  Args:
    wav_encoder_output: Output from wav encoder.
    txt_encoder_output: Output from txt encoder.
    wav_enc_dec_attention_bias: a bias tensor for use in enc-dec attention over
      wav inputs
    txt_enc_dec_attention_bias: a bias tensor for use in enc-dec attention over
      txt inputs
    symbols_to_logits_fn: Incremental decoding; function mapping triple `(ids,
      step, cache)` to symbol logits.
    hparams: run hyperparameters
    decode_length: an integer. How many additional timesteps to decode.
    vocab_size: Output vocabulary size.
    beam_size: number of beams.
    top_beams: an integer. How many of the beams to return.
    alpha: Float that controls the length penalty. Larger the alpha, stronger
      the preference for longer translations.
    eos_id: End-of-sequence symbol in beam search.
    batch_size: an integer scalar - must be passed if there is no input
    force_decode_length: bool, whether to force the full decode length, or if
      False, stop when all beams hit eos_id.

  Returns:
    A dict of decoding results {
        "outputs": integer `Tensor` of decoded ids of shape
            [batch_size, <= decode_length] if top_beams == 1 or
            [batch_size, top_beams, <= decode_length] otherwise
        "scores": decoding log probs from the beam search,
            None if using greedy decoding (beam_size=1)
    }
  """
  if wav_encoder_output is not None:
    batch_size = common_layers.shape_list(wav_encoder_output)[0]

  key_channels = hparams.attention_key_channels or hparams.hidden_size
  value_channels = hparams.attention_value_channels or hparams.hidden_size
  num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers

  cache = {
      "layer_%d" % layer: {
          "k": tf.zeros([batch_size, 0, key_channels]),
          "v": tf.zeros([batch_size, 0, value_channels]),
          "f": tf.zeros([batch_size, 0, hparams.hidden_size]),
      } for layer in range(num_layers)
  }

  # Tensors cannot be used directly as Python booleans; compare to None.
  if txt_encoder_output is not None and wav_encoder_output is not None:
    cache["wav_enc_output"] = wav_encoder_output
    cache["txt_enc_output"] = txt_encoder_output
    cache["wav_enc_dec_attention_bias"] = wav_enc_dec_attention_bias
    cache["txt_enc_dec_attention_bias"] = txt_enc_dec_attention_bias

  if beam_size > 1:  # Beam Search
    initial_ids = tf.zeros([batch_size], dtype=tf.int32)
    decoded_ids, scores = beam_search.beam_search(
        symbols_to_logits_fn,
        initial_ids,
        beam_size,
        decode_length,
        vocab_size,
        alpha,
        states=cache,
        eos_id=eos_id,
        stop_early=(top_beams == 1))

    if top_beams == 1:
      decoded_ids = decoded_ids[:, 0, 1:]
      scores = scores[:, 0]
    else:
      decoded_ids = decoded_ids[:, :top_beams, 1:]
      scores = scores[:, :top_beams]
  else:  # Greedy search

    def inner_loop(i, hit_eos, next_id, decoded_ids, cache, log_prob):
      """One step of greedy decoding."""
      logits, cache = symbols_to_logits_fn(next_id, i, cache)
      log_probs = common_layers.log_prob_from_logits(logits)
      temperature = (0.0 if hparams.sampling_method == "argmax" else
                     hparams.sampling_temp)
      next_id = common_layers.sample_with_temperature(logits, temperature)
      hit_eos |= tf.equal(next_id, eos_id)

      log_prob_indices = tf.stack(
          [tf.range(tf.to_int64(batch_size)), next_id], axis=1)
      log_prob += tf.gather_nd(log_probs, log_prob_indices)

      next_id = tf.expand_dims(next_id, axis=1)
      decoded_ids = tf.concat([decoded_ids, next_id], axis=1)
      return i + 1, hit_eos, next_id, decoded_ids, cache, log_prob

    def is_not_finished(i, hit_eos, *_):
      finished = i >= decode_length
      if not force_decode_length:
        finished |= tf.reduce_all(hit_eos)
      return tf.logical_not(finished)

    decoded_ids = tf.zeros([batch_size, 0], dtype=tf.int64)
    hit_eos = tf.fill([batch_size], False)
    next_id = tf.zeros([batch_size, 1], dtype=tf.int64)
    initial_log_prob = tf.zeros([batch_size], dtype=tf.float32)
    _, _, _, decoded_ids, _, log_prob = tf.while_loop(
        is_not_finished,
        inner_loop, [
            tf.constant(0), hit_eos, next_id, decoded_ids, cache,
            initial_log_prob
        ],
        shape_invariants=[
            tf.TensorShape([]),
            tf.TensorShape([None]),
            tf.TensorShape([None, None]),
            tf.TensorShape([None, None]),
            nest.map_structure(beam_search.get_state_shape_invariants, cache),
            tf.TensorShape([None]),
        ])
    scores = log_prob

  return {"outputs": decoded_ids, "scores": scores}
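# Sketch (NumPy only) contrasting the two log-prob accumulators seen in the
# greedy loops in this file: the variant above keeps adding after EOS, while
# the earlier tagged variant masks updates with (1 - hit_eos) so that tokens
# emitted after a sequence has finished cannot change its score.  Numbers are
# made up.
import numpy as np

step_log_probs = np.array([-0.1, -0.2, -3.0])    # per-step chosen-token log probs
hit_eos_before = np.array([False, False, True])  # EOS was hit after step 2

unmasked = step_log_probs.sum()                           # -3.3
masked = (step_log_probs * (1.0 - hit_eos_before)).sum()  # -0.3
print(unmasked, masked)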
def fast_decode(encoder_output,
                encoder_decoder_attention_bias,
                symbols_to_logits_fn,
                hparams,
                decode_length,
                vocab_size,
                beam_size=1,
                top_beams=1,
                alpha=1.0,
                eos_id=beam_search.EOS_ID):
  """Given encoder output and a symbols to logits function, does fast decoding.

  Implements both greedy and beam search decoding, uses beam search iff
  beam_size > 1, otherwise beam search related arguments are ignored.

  Args:
    encoder_output: Output from encoder.
    encoder_decoder_attention_bias: a bias tensor for use in encoder-decoder
      attention
    symbols_to_logits_fn: Incremental decoding; function mapping triple `(ids,
      step, cache)` to symbol logits.
    hparams: run hyperparameters
    decode_length: an integer. How many additional timesteps to decode.
    vocab_size: Output vocabulary size.
    beam_size: number of beams.
    top_beams: an integer. How many of the beams to return.
    alpha: Float that controls the length penalty. Larger the alpha, stronger
      the preference for longer translations.
    eos_id: End-of-sequence symbol in beam search.

  Returns:
    Pair of tensors `(decoded_ids, scores)`, where `decoded_ids` is a 2-d or
    3-d (when doing beam search with top_beams > 1) tensor containing the
    result of decoding, and `scores` is the beam search scores.
  """
  batch_size = common_layers.shape_list(encoder_output)[0]

  key_channels = hparams.attention_key_channels or hparams.hidden_size
  value_channels = hparams.attention_value_channels or hparams.hidden_size
  num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers

  cache = {
      "layer_%d" % layer: {
          "k": tf.zeros([batch_size, 0, key_channels]),
          "v": tf.zeros([batch_size, 0, value_channels]),
      } for layer in range(num_layers)
  }

  # Set 2nd dim to None since it's not invariant in the tf.while_loop
  # Note: Tensor.set_shape() does not work here since it merges shape info.
  # TODO(llion); Find a more robust solution.
  # pylint: disable=protected-access
  if not context.in_eager_mode():
    for layer in cache:
      cache[layer]["k"]._shape = tf.TensorShape([None, None, key_channels])
      cache[layer]["v"]._shape = tf.TensorShape([None, None, value_channels])
  # pylint: enable=protected-access
  cache["encoder_output"] = encoder_output
  cache["encoder_decoder_attention_bias"] = encoder_decoder_attention_bias

  if beam_size > 1:  # Beam Search
    initial_ids = tf.zeros([batch_size], dtype=tf.int32)
    decoded_ids, scores = beam_search.beam_search(
        symbols_to_logits_fn,
        initial_ids,
        beam_size,
        decode_length,
        vocab_size,
        alpha,
        states=cache,
        eos_id=eos_id,
        stop_early=(top_beams == 1))

    if top_beams == 1:
      decoded_ids = decoded_ids[:, 0, 1:]
    else:
      decoded_ids = decoded_ids[:, :top_beams, 1:]
  else:  # Greedy

    def inner_loop(i, next_id, decoded_ids, cache):
      logits, cache = symbols_to_logits_fn(next_id, i, cache)
      temperature = (0.0 if hparams.sampling_method == "argmax" else
                     hparams.sampling_temp)
      next_id = tf.expand_dims(
          common_layers.sample_with_temperature(logits, temperature), axis=1)
      decoded_ids = tf.concat([decoded_ids, next_id], axis=1)
      return i + 1, next_id, decoded_ids, cache

    decoded_ids = tf.zeros([batch_size, 0], dtype=tf.int64)
    scores = None
    next_id = tf.zeros([batch_size, 1], dtype=tf.int64)
    _, _, decoded_ids, _ = tf.while_loop(
        # TODO(llion): Early stopping.
        lambda i, *_: tf.less(i, decode_length),
        inner_loop,
        [tf.constant(0), next_id, decoded_ids, cache],
        shape_invariants=[
            tf.TensorShape([]),
            tf.TensorShape([None, None]),
            tf.TensorShape([None, None]),
            nest.map_structure(lambda t: tf.TensorShape(t.shape), cache),
        ])

  return decoded_ids, scores
def _fast_decode(self,
                 features,
                 decode_length,
                 beam_size=1,
                 top_beams=1,
                 alpha=1.0):
  """Fast decoding.

  Implements both greedy and beam search decoding, uses beam search iff
  beam_size > 1, otherwise beam search related arguments are ignored.

  Args:
    features: a map of string to model features.
    decode_length: an integer. How many additional timesteps to decode.
    beam_size: number of beams.
    top_beams: an integer. How many of the beams to return.
    alpha: Float that controls the length penalty. Larger the alpha, stronger
      the preference for longer translations.

  Returns:
    samples: an integer `Tensor`. Top samples from the beam search

  Raises:
    NotImplementedError: If there are multiple data shards.
  """
  if self._num_datashards != 1:
    raise NotImplementedError("Fast decoding only supports a single shard.")
  dp = self._data_parallelism
  hparams = self._hparams

  inputs = features["inputs"]
  batch_size = common_layers.shape_list(inputs)[0]
  target_modality = self._problem_hparams.target_modality
  if target_modality.is_class_modality:
    decode_length = 1
  else:
    decode_length = common_layers.shape_list(inputs)[1] + decode_length

  # TODO(llion): Clean up this reshaping logic.
  inputs = tf.expand_dims(inputs, axis=1)
  if len(inputs.shape) < 5:
    inputs = tf.expand_dims(inputs, axis=4)
  s = common_layers.shape_list(inputs)
  inputs = tf.reshape(inputs, [s[0] * s[1], s[2], s[3], s[4]])
  # _shard_features called to ensure that the variable names match
  inputs = self._shard_features({"inputs": inputs})["inputs"]
  input_modality = self._problem_hparams.input_modality["inputs"]
  with tf.variable_scope(input_modality.name):
    inputs = input_modality.bottom_sharded(inputs, dp)
  with tf.variable_scope("body"):
    encoder_output, encoder_decoder_attention_bias = dp(
        self.encode, inputs, features["target_space_id"], hparams,
        features=features)
  encoder_output = encoder_output[0]
  encoder_decoder_attention_bias = encoder_decoder_attention_bias[0]

  if hparams.pos == "timing":
    timing_signal = common_attention.get_timing_signal_1d(
        decode_length + 1, hparams.hidden_size)

  def preprocess_targets(targets, i):
    """Performs preprocessing steps on the targets to prepare for the decoder.

    This includes:
      - Embedding the ids.
      - Flattening to 3D tensor.
      - Optionally adding timing signals.

    Args:
      targets: inputs ids to the decoder. [batch_size, 1]
      i: scalar, Step number of the decoding loop.

    Returns:
      Processed targets [batch_size, 1, hidden_dim]
    """
    # _shard_features called to ensure that the variable names match
    targets = self._shard_features({"targets": targets})["targets"]
    with tf.variable_scope(target_modality.name):
      targets = target_modality.targets_bottom_sharded(targets, dp)[0]
    targets = common_layers.flatten4d3d(targets)

    # TODO(llion): Explain! Is this even needed?
    targets = tf.cond(
        tf.equal(i, 0), lambda: tf.zeros_like(targets), lambda: targets)

    if hparams.pos == "timing":
      targets += timing_signal[:, i:i + 1]
    return targets

  decoder_self_attention_bias = (
      common_attention.attention_bias_lower_triangle(decode_length))
  if hparams.proximity_bias:
    decoder_self_attention_bias += common_attention.attention_bias_proximal(
        decode_length)

  def symbols_to_logits_fn(ids, i, cache):
    """Go from ids to logits for next symbol."""
    ids = ids[:, -1:]
    targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
    targets = preprocess_targets(targets, i)

    bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]

    with tf.variable_scope("body"):
      body_outputs = dp(
          self.decode, targets, cache["encoder_output"],
          cache["encoder_decoder_attention_bias"], bias, hparams, cache,
          nonpadding=transformer._features_to_nonpadding(features, "targets"))

    with tf.variable_scope(target_modality.name):
      logits = target_modality.top_sharded(body_outputs, None, dp)[0]

    return tf.squeeze(logits, axis=[1, 2, 3]), cache

  key_channels = hparams.attention_key_channels or hparams.hidden_size
  value_channels = hparams.attention_value_channels or hparams.hidden_size
  num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers

  cache = {
      "layer_%d" % layer: {
          "k": tf.zeros([batch_size, 0, key_channels]),
          "v": tf.zeros([batch_size, 0, value_channels]),
      } for layer in range(num_layers)
  }

  # Set 2nd dim to None since it's not invariant in the tf.while_loop
  # Note: Tensor.set_shape() does not work here since it merges shape info.
  # TODO(llion); Find a more robust solution.
  # pylint: disable=protected-access
  if not context.in_eager_mode():
    for layer in cache:
      cache[layer]["k"]._shape = tf.TensorShape([None, None, key_channels])
      cache[layer]["v"]._shape = tf.TensorShape([None, None, value_channels])
  # pylint: enable=protected-access
  cache["encoder_output"] = encoder_output
  cache["encoder_decoder_attention_bias"] = encoder_decoder_attention_bias

  if beam_size > 1:  # Beam Search
    target_modality = (
        self._hparams.problems[self._problem_idx].target_modality)
    vocab_size = target_modality.top_dimensionality
    initial_ids = tf.zeros([batch_size], dtype=tf.int32)
    decoded_ids, scores = beam_search.beam_search(
        symbols_to_logits_fn,
        initial_ids,
        beam_size,
        decode_length,
        vocab_size,
        alpha,
        states=cache,
        stop_early=(top_beams == 1))
    decoded_ids = decoded_ids[:, :, 1:]

    # do roulette wheel selection or inverse roulette wheel selection
    if self._hparams.roulette == "Normal" or self._hparams.roulette == "Inverse":
      if self._hparams.roulette == "Normal":
        probabilities = tf.pow(tf.constant(2.0), scores)
        start = 0
      else:
        probabilities = tf.subtract(
            tf.constant(1.0), tf.pow(tf.constant(2.0), scores))
        start = beam_size - self._hparams.roulette_beam_size

      summ = tf.reduce_sum(probabilities)
      ex_probs = tf.divide(probabilities, summ)
      # ex_probs = tf.nn.softmax(probabilities)

      # sample a number between 0 and 1
      wheel = tf.random_uniform([1])
      upper_bound = tf.constant(0.0)

      # change this as well if using inverse
      for i in range(start, self._hparams.roulette_beam_size):
        upper_bound = tf.add(ex_probs[:, i], upper_bound)
        truthValue = tf.squeeze(
            tf.logical_and(wheel >= upper_bound - ex_probs[:, i],
                           wheel <= upper_bound))
        decoded_ids, scores, i = tf.cond(
            truthValue,
            lambda: (decoded_ids[:, i, :], scores[:, i], beam_size),
            lambda: (decoded_ids, scores, i))
  else:  # Greedy

    def inner_loop(i, next_id, decoded_ids, cache):
      logits, cache = symbols_to_logits_fn(next_id, i, cache)
      temperature = (0.0 if hparams.sampling_method == "argmax" else
                     hparams.sampling_temp)
      next_id = tf.expand_dims(
          common_layers.sample_with_temperature(logits, temperature), axis=1)
      decoded_ids = tf.concat([decoded_ids, next_id], axis=1)
      return i + 1, next_id, decoded_ids, cache

    decoded_ids = tf.zeros([batch_size, 0], dtype=tf.int64)
    scores = None
    next_id = tf.zeros([batch_size, 1], dtype=tf.int64)
    _, _, decoded_ids, _ = tf.while_loop(
        # TODO(llion): Early stopping.
        lambda i, *_: tf.less(i, decode_length),
        inner_loop,
        [tf.constant(0), next_id, decoded_ids, cache],
        shape_invariants=[
            tf.TensorShape([]),
            tf.TensorShape([None, None]),
            tf.TensorShape([None, None]),
            nest.map_structure(lambda t: tf.TensorShape(t.shape), cache),
        ])

  return decoded_ids, scores
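# Sketch (NumPy only) of the roulette-wheel selection above: beam scores are
# turned into a probability distribution and a uniform sample picks one beam
# in proportion to (or, in "Inverse" mode, against) its probability mass.
# `roulette_pick` is a hypothetical helper; it assumes the scores behave like
# log probabilities, matching the pow(2, scores) conversion used above.
import numpy as np

def roulette_pick(scores, inverse=False, rng=np.random):
  probs = np.power(2.0, scores)
  if inverse:
    probs = 1.0 - probs
  probs /= probs.sum()
  wheel = rng.uniform()
  # Walk the cumulative mass until the sampled point falls inside a slot.
  return int(np.searchsorted(np.cumsum(probs), wheel))

scores = np.array([-0.5, -1.5, -3.0])  # hypothetical per-beam scores
print(roulette_pick(scores))           # usually beam 0, sometimes 1 or 2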
def _beam_decode(self, features, decode_length, beam_size, top_beams,
                 last_position_only, alpha):
  """Beam search decoding.

  Args:
    features: a map of string to `Tensor`
    decode_length: an integer. How many additional timesteps to decode.
    beam_size: number of beams.
    top_beams: an integer. How many of the beams to return.
    last_position_only: a boolean, speed-up by computing last position only.
    alpha: Float that controls the length penalty. Larger the alpha, stronger
      the preference for longer translations.

  Returns:
    samples: an integer `Tensor`. Top samples from the beam search
  """

  def symbols_to_logits_fn(ids):
    """Go from ids to logits."""
    ids = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
    ids = tf.pad(ids[:, 1:], [[0, 0], [0, 1], [0, 0], [0, 0]])

    features["targets"] = ids
    self._coverage = None
    sharded_logits, _, _ = self.model_fn(
        features, False, last_position_only=last_position_only)
    # now self._coverage is a coverage tensor for the first datashard.
    # it has shape [batch_size] and contains floats between 0 and
    # source_length.
    logits = sharded_logits[0]  # Assuming we have one shard.
    if last_position_only:
      return tf.squeeze(logits, axis=[1, 2, 3])
    current_output_position = tf.shape(ids)[1] - 1  # -1 due to the pad above.
    logits = logits[:, current_output_position, :, :]
    return tf.squeeze(logits, axis=[1, 2])

  batch_size = tf.shape(features["inputs"])[0]
  initial_ids = tf.zeros([batch_size], dtype=tf.int32)

  inputs_old = features["inputs"]
  features["inputs"] = tf.expand_dims(features["inputs"], 1)
  if len(features["inputs"].shape) < 5:
    features["inputs"] = tf.expand_dims(features["inputs"], 4)
  # Expand the inputs into the beam size.
  features["inputs"] = tf.tile(features["inputs"], [1, beam_size, 1, 1, 1])
  s = tf.shape(features["inputs"])
  features["inputs"] = tf.reshape(features["inputs"],
                                  [s[0] * s[1], s[2], s[3], s[4]])

  target_modality = self._hparams.problems[self._problem_idx].target_modality
  vocab_size = target_modality.top_dimensionality
  # Setting decode length to input length + decode_length
  decode_length = tf.shape(features["inputs"])[1] + tf.constant(decode_length)
  ids, scores = beam_search.beam_search(symbols_to_logits_fn, initial_ids,
                                        beam_size, decode_length, vocab_size,
                                        alpha)

  # Set inputs back to the unexpanded inputs so as not to confuse the
  # Estimator!
  features["inputs"] = inputs_old

  # Return `top_beams` decodings (also remove initial id from the beam search)
  return_scores = False  # TODO(lukaszkaiser): make it work multi-problem.
  if top_beams == 1:
    if return_scores:
      return {"outputs": ids[:, 0, 1:], "scores": scores}
    return ids[:, 0, 1:]
  else:
    if return_scores:
      return {"outputs": ids[:, :top_beams, 1:], "scores": scores}
    return ids[:, :top_beams, 1:]