def setUp(self):
    """Create the beam search configuration shared by these tests."""
    super(TestBeamStep, self).setUp()
    self.state_size = 10
    # Small beam/vocab keeps the step tests fast and easy to reason about.
    self.config = beam_search.BeamSearchConfig(
        beam_width=3,
        vocab_size=5,
        eos_token=0,
        length_penalty_weight=0.6,
        choose_successors_fn=beam_search.choose_top_k)
def setUp(self):
    """Create the beam search configuration shared by these tests."""
    super(TestBeamStep, self).setUp()
    self.state_size = 10
    # Small beam/vocab keeps the step tests fast and easy to reason about.
    self.config = beam_search.BeamSearchConfig(
        beam_width=3,
        vocab_size=5,
        eos_token=0,
        score_fn=beam_search.logprob_score,
        choose_successors_fn=beam_search.choose_top_k)
def _get_beam_search_decoder(self, decoder):
    """Wrap `decoder` in an ensemble beam search decoder.

    Args:
      decoder: The original decoder.

    Returns:
      An EnsembleBeamSearchDecoder built from the inference params.
    """
    params = self.params
    config = beam_search.BeamSearchConfig(
        beam_width=params["inference.beam_search.beam_width"],
        vocab_size=self.target_vocab_info.total_size,
        eos_token=self.target_vocab_info.special_vocab.SEQUENCE_END,
        length_penalty_weight=params[
            "inference.beam_search.length_penalty_weight"],
        choose_successors_fn=getattr(
            beam_search,
            params["inference.beam_search.choose_successors_fn"]))
    # TODO: add a beam search decoder specialized for ensembles.
    return EnsembleBeamSearchDecoder(decoder, config)
def test_with_beam_search(self):
    """Decode with beam search and verify all output tensor shapes."""
    # Beam search requires a batch size of exactly 1.
    self.batch_size = 1
    config = beam_search.BeamSearchConfig(
        beam_width=10,
        vocab_size=self.vocab_size,
        eos_token=self.vocab_size - 2,
        length_penalty_weight=0.6,
        choose_successors_fn=beam_search.choose_top_k)
    embeddings = tf.get_variable(
        "W_embed", [self.vocab_size, self.input_depth])
    helper = decode_helper.GreedyEmbeddingHelper(
        embedding=embeddings,
        start_tokens=[0] * config.beam_width,
        end_token=-1)
    decoder_fn = self.create_decoder(
        helper=helper, mode=tf.contrib.learn.ModeKeys.INFER)
    decoder_fn = beam_search_decoder.BeamSearchDecoder(
        decoder=decoder_fn, config=config)
    initial_state = decoder_fn.cell.zero_state(
        self.batch_size, dtype=tf.float32)
    decoder_output, _ = decoder_fn(initial_state, helper)  #pylint: disable=E1101

    with self.test_session() as sess:
        sess.run(tf.global_variables_initializer())
        decoder_output_ = sess.run(decoder_output)

        # Every beam-search output is [time, batch=1, beam_width].
        expected_shape = [self.max_decode_length, 1, config.beam_width]
        np.testing.assert_array_equal(
            decoder_output_.predicted_ids.shape, expected_shape)
        np.testing.assert_array_equal(
            decoder_output_.beam_search_output.beam_parent_ids.shape,
            expected_shape)
        np.testing.assert_array_equal(
            decoder_output_.beam_search_output.scores.shape,
            expected_shape)
        np.testing.assert_array_equal(
            decoder_output_.beam_search_output.original_outputs
            .predicted_ids.shape, expected_shape)
        # Logits carry an extra trailing vocabulary dimension.
        np.testing.assert_array_equal(
            decoder_output_.beam_search_output.original_outputs.logits
            .shape, expected_shape + [self.vocab_size])
    return decoder_output
def _get_beam_search_decoder(self, decoder):
    """Wrap a decoder into a Beam Search decoder.

    Args:
      decoder: The original decoder.

    Returns:
      A BeamSearchDecoder with the same interfaces as the original
      decoder.
    """
    params = self.params
    config = beam_search.BeamSearchConfig(
        beam_width=params["inference.beam_search.beam_width"],
        vocab_size=self.target_vocab_info.total_size,
        eos_token=self.target_vocab_info.special_vocab.SEQUENCE_END,
        score_fn=getattr(
            beam_search, params["inference.beam_search.score_fn"]),
        choose_successors_fn=getattr(
            beam_search,
            params["inference.beam_search.choose_successors_fn"]))
    return BeamSearchDecoder(decoder=decoder, config=config)
def _create_decoder(self, encoder_output, features, _labels):
    """Instantiate the decoder class configured for beam search inference."""
    vocab_size = self.target_vocab_info.total_size
    eos_token = self.target_vocab_info.special_vocab.SEQUENCE_END
    config = beam_search.BeamSearchConfig(
        beam_width=self.params["inference.beam_search.beam_width"],
        vocab_size=vocab_size,
        eos_token=eos_token,
        length_penalty_weight=self.params[
            "inference.beam_search.length_penalty_weight"],
        choose_successors_fn=getattr(
            beam_search,
            self.params["inference.beam_search.choose_successors_fn"]))
    # NOTE(review): start_tokens is SEQUENCE_END in the original code —
    # presumably the decoder seeds generation with the EOS id; confirm.
    return self.decoder_class(
        params=self.params["decoder.params"],
        mode=self.mode,
        vocab_size=vocab_size,
        config=config,
        target_embedding=self.target_embedding_fairseq(),
        pos_embedding=self.target_pos_embedding_fairseq(),
        start_tokens=eos_token)