def test_with_dynamic_inputs(self):
    """Decode in INFER mode with a greedy embedding helper and check shapes.

    Creates a ``W_embed`` variable of shape ``[vocab_size, input_depth]``,
    wraps it in a ``GreedyEmbeddingHelper`` whose start token is id 0 for
    every batch entry, decodes from the cell's zero state, and asserts that
    the evaluated output has logits of shape
    ``[max_decode_length, batch_size, vocab_size]`` and predicted ids of
    shape ``[max_decode_length, batch_size]``.
    """
    embedding_var = tf.get_variable(
        "W_embed", [self.vocab_size, self.input_depth])
    # end_token=-1 presumably never matches a predicted id, so decoding is
    # expected to run for the full max_decode_length (the assertions below
    # rely on that) -- TODO confirm against GreedyEmbeddingHelper.
    greedy_helper = decode_helper.GreedyEmbeddingHelper(
        embedding=embedding_var,
        start_tokens=[0] * self.batch_size,
        end_token=-1)
    infer_mode = tf.contrib.learn.ModeKeys.INFER
    decoder = self.create_decoder(helper=greedy_helper, mode=infer_mode)
    zero_state = decoder.cell.zero_state(self.batch_size, dtype=tf.float32)
    decoder_output, _ = decoder(zero_state, greedy_helper)  # pylint: disable=E1101

    with self.test_session() as sess:
        sess.run(tf.global_variables_initializer())
        evaluated = sess.run(decoder_output)
        steps_by_batch = [self.max_decode_length, self.batch_size]
        np.testing.assert_array_equal(
            evaluated.logits.shape, steps_by_batch + [self.vocab_size])
        np.testing.assert_array_equal(
            evaluated.predicted_ids.shape, steps_by_batch)
def _decode_infer(self, decoder, bridge, _encoder_output, features, labels):
    """Runs decoding in inference mode.

    Builds a ``GreedyEmbeddingHelper`` over ``self.target_embedding`` that
    starts every sequence with the target vocabulary's SEQUENCE_START id and
    stops on SEQUENCE_END, then invokes ``decoder`` from the state produced
    by ``bridge()``.

    Args:
      decoder: Callable invoked as ``decoder(initial_state, helper)``.
      bridge: Zero-argument callable returning the decoder's initial state.
      _encoder_output: Unused here.
      features: Passed to ``self.batch_size`` to determine the batch size.
      labels: Passed to ``self.batch_size`` to determine the batch size.

    Returns:
      Whatever ``decoder(initial_state, helper)`` returns.
    """
    # self.batch_size is always called (as before), even when beam search
    # overrides the result; it may build graph ops.
    inferred_batch_size = self.batch_size(features, labels)
    # With beam search the number of start tokens is the beam width
    # (presumably one decoding slot per beam -- verify against the beam
    # search decoder).
    batch_size = (self.params["inference.beam_search.beam_width"]
                  if self.use_beam_search else inferred_batch_size)

    special_vocab = self.target_vocab_info.special_vocab
    start_id = special_vocab.SEQUENCE_START
    infer_helper = tf_decode_helper.GreedyEmbeddingHelper(
        embedding=self.target_embedding,
        start_tokens=tf.fill([batch_size], start_id),
        end_token=special_vocab.SEQUENCE_END)
    # bridge() is deliberately called after the helper is built, matching
    # the original graph-construction order.
    initial_state = bridge()
    return decoder(initial_state, infer_helper)
def test_with_beam_search(self):
    """Wrap the decoder in a BeamSearchDecoder and check output shapes.

    Uses a beam width of 10 with ``choose_top_k`` successor selection. It
    asserts that predicted ids, beam parent ids, scores and the wrapped
    decoder's predicted ids all have shape
    ``[max_decode_length, 1, beam_width]``. It also asserts that the wrapped
    decoder's logits have shape
    ``[max_decode_length, 1, beam_width, vocab_size]``.

    Returns:
      The symbolic (unevaluated) ``decoder_output``. Presumably this lets
      derived test classes reuse it -- TODO confirm.
      NOTE(review): Python 3.11+ unittest warns when a test method returns
      a non-None value; the return is kept so existing callers keep working.
    """
    self.batch_size = 1  # Batch size for beam search must be 1.
    config = beam_search.BeamSearchConfig(
        beam_width=10,
        vocab_size=self.vocab_size,
        eos_token=self.vocab_size - 2,
        length_penalty_weight=0.6,
        choose_successors_fn=beam_search.choose_top_k)

    embedding_var = tf.get_variable(
        "W_embed", [self.vocab_size, self.input_depth])
    # One start token per beam; end_token=-1 presumably never matches a
    # predicted id, so decoding runs for max_decode_length steps.
    greedy_helper = decode_helper.GreedyEmbeddingHelper(
        embedding=embedding_var,
        start_tokens=[0] * config.beam_width,
        end_token=-1)
    base_decoder = self.create_decoder(
        helper=greedy_helper, mode=tf.contrib.learn.ModeKeys.INFER)
    beam_decoder = beam_search_decoder.BeamSearchDecoder(
        decoder=base_decoder, config=config)

    zero_state = beam_decoder.cell.zero_state(
        self.batch_size, dtype=tf.float32)
    decoder_output, _ = beam_decoder(zero_state, greedy_helper)  # pylint: disable=E1101

    def _assert_shape(array, expected_shape):
        np.testing.assert_array_equal(array.shape, expected_shape)

    with self.test_session() as sess:
        sess.run(tf.global_variables_initializer())
        evaluated = sess.run(decoder_output)
        per_beam_shape = [self.max_decode_length, 1, config.beam_width]
        _assert_shape(evaluated.predicted_ids, per_beam_shape)
        _assert_shape(
            evaluated.beam_search_output.beam_parent_ids, per_beam_shape)
        _assert_shape(evaluated.beam_search_output.scores, per_beam_shape)
        _assert_shape(
            evaluated.beam_search_output.original_outputs.predicted_ids,
            per_beam_shape)
        _assert_shape(
            evaluated.beam_search_output.original_outputs.logits,
            per_beam_shape + [self.vocab_size])

    return decoder_output