def initialize(self, name=None):
    """Build the first-step values for beam-search decoding.

    The start token is embedded once per beam, given a new axis at
    position 1, and zero-padded along that axis with
    ``max_decode_length - 1`` extra positions. The encoder ``outputs`` and
    ``attention_values`` held in ``self.initial_state`` are repeated
    ``beam_width`` times along their leading axis; ``final_state`` and
    ``attention_values_length`` are passed through unchanged.

    Args:
        name: Not used by this method.

    Returns:
        A tuple ``(finished, first_inputs, (enc_output, beam_state))``.
    """
    beam_width = self.config.beam_width

    # One "not finished" flag per beam.
    finished = tf.tile([False], [beam_width])

    # Embed the start token for every beam, then add an axis at position 1.
    start_ids = tf.fill([beam_width], self.start_tokens)
    start_embedded = tf.nn.embedding_lookup(self.target_embedding, start_ids)
    start_embedded = tf.expand_dims(start_embedded, 1)

    # Zero-fill the remaining positions along axis 1.
    remaining_steps = self.params['max_decode_length'] - 1
    embedding_dim = self.target_embedding.get_shape().as_list()[-1]
    padding = tf.zeros([beam_width, remaining_steps, embedding_dim])
    first_inputs = tf.concat([start_embedded, padding], axis=1)

    beam_state = beam_search.create_initial_beam_state(self.config)

    # Repeat the encoder tensors once per beam along the leading axis.
    encoder_state = self.initial_state
    beam_multiples = [beam_width, 1, 1]
    enc_output = EncoderOutput(
        outputs=tf.tile(encoder_state.outputs, beam_multiples),
        final_state=encoder_state.final_state,
        attention_values=tf.tile(
            encoder_state.attention_values, beam_multiples),
        attention_values_length=encoder_state.attention_values_length)

    return finished, first_inputs, (enc_output, beam_state)
def initialize(self, name=None):
    """Return the initial ``finished`` flags, inputs and state for decoding.

    ``first_inputs`` is the embedded start token (one copy per beam, with an
    axis inserted at position 1) concatenated along axis 1 with a block of
    zeros covering ``max_decode_length - 1`` positions. The returned state
    pairs an ``EncoderOutput`` whose ``outputs`` and ``attention_values``
    are tiled ``beam_width`` times along the first axis with a freshly
    created beam state.

    Args:
        name: Not used by this method.

    Returns:
        ``(finished, first_inputs, (enc_output, beam_state))``.
    """
    width = self.config.beam_width

    def _tile_per_beam(tensor):
        # Repeat along the leading axis only; the other two axes are kept.
        return tf.tile(tensor, [width, 1, 1])

    # No beam is finished at the start.
    finished = tf.tile([False], [width])

    token_batch = tf.fill([width], self.start_tokens)
    embedded_start = tf.expand_dims(
        tf.nn.embedding_lookup(self.target_embedding, token_batch), 1)

    # Zeros for every position after the first along axis 1.
    pad_shape = [
        width,
        self.params['max_decode_length'] - 1,
        self.target_embedding.get_shape().as_list()[-1],
    ]
    first_inputs = tf.concat([embedded_start, tf.zeros(pad_shape)], axis=1)

    beam_state = beam_search.create_initial_beam_state(self.config)

    source = self.initial_state
    tiled_outputs = _tile_per_beam(source.outputs)
    tiled_attention = _tile_per_beam(source.attention_values)
    enc_output = EncoderOutput(
        outputs=tiled_outputs,
        final_state=source.final_state,
        attention_values=tiled_attention,
        attention_values_length=source.attention_values_length)

    return finished, first_inputs, (enc_output, beam_state)
def initialize(self, name=None):
    """Initialize the wrapped decoder and pair its state with a beam state.

    Args:
        name: Not used by this method.

    Returns:
        ``(finished, first_inputs, (initial_state, beam_state))`` where the
        first three values come from ``self.decoder.initialize()`` and
        ``beam_state`` is created from ``self.config``.
    """
    decoder_finished, decoder_inputs, decoder_state = self.decoder.initialize()

    # Fresh beam-search bookkeeping built from this decoder's config.
    initial_beam = beam_search.create_initial_beam_state(config=self.config)

    combined_state = (decoder_state, initial_beam)
    return decoder_finished, decoder_inputs, combined_state
def step(self, time_, cell_output, cell_state, loop_state):
    """Run one step of the wrapped decoder followed by beam search.

    On the initial call (``cell_output is None``) the cell state is tiled
    ``beam_width`` times, the wrapped decoder is stepped once, and a fresh
    beam state is created. On later calls the previous beam state is
    unwrapped from ``loop_state``, the wrapped decoder is stepped, a beam
    search step is applied to its logits, and the next cell state is
    gathered according to ``beam_parent_ids``.

    NOTE(review): the source formatting of this method was flattened; the
    final ``nest_map``/``tf.gather`` is assumed to belong to the non-initial
    branch (otherwise the tiled state from the initial branch would be
    overwritten) -- confirm against the upstream file.

    Args:
        time_: Current time step; used to index ``predicted_ids`` at
            ``time_ - 1`` and forwarded to the wrapped decoder.
        cell_output: Output of the cell from the previous step, or ``None``
            on the initial call.
        cell_state: Cell state forwarded to the wrapped decoder.
        loop_state: On the initial call, forwarded as-is to the wrapped
            decoder; afterwards, a wrapped value that
            ``self._unwrap_loop_state`` splits into the previous beam state
            and the wrapped decoder's own loop state.

    Returns:
        A ``DecoderStepOutput`` with ``outputs``, ``next_cell_state`` and
        ``next_loop_state``.
    """
    initial_call = (cell_output is None)
    if initial_call:
        # NOTE: this zero tensor is not read again in this method; the
        # wrapped decoder below is still called with ``None``.
        cell_output = tf.zeros(
            [self.config.beam_width, self.cell.output_size])
        # We start out with all beams being equal, so we tile the cell state
        # [beam_width] times
        next_cell_state = beam_search.nest_map(
            cell_state, lambda x: tf.tile(x, [self.config.beam_width, 1]))
        # Call the original decoder
        original_outputs = self.decoder.step(time_, None, cell_state, loop_state)
        # Create an initial Beam State
        beam_state = beam_search.create_initial_beam_state(
            config=self.config, max_time=self.decoder.max_decode_length)
        next_loop_state = self._wrap_loop_state(
            beam_state, original_outputs.next_loop_state)
        # On the initial call the outputs are whatever ``output_shapes()``
        # returns rather than a ``BeamDecoderOutput`` built from a beam step.
        outputs = self.output_shapes()
    else:
        # Split the loop state back into the beam-search part and the part
        # owned by the wrapped decoder.
        prev_beam_state, original_loop_state = self._unwrap_loop_state(
            loop_state)
        # Call the original decoder
        original_outputs = self.decoder.step(time_, cell_output, cell_state,
                                             original_loop_state)
        # Perform a step of beam search
        beam_state = beam_search.beam_search_step(
            logits=original_outputs.outputs.logits,
            beam_state=prev_beam_state,
            config=self.config)
        # Pin the static shape of the second dimension to max_decode_length.
        beam_state.predicted_ids.set_shape(
            [None, self.decoder.max_decode_length])
        next_loop_state = self._wrap_loop_state(
            beam_state, original_outputs.next_loop_state)
        # ``logits`` is filled with zeros here; the wrapped decoder's own
        # outputs are carried along in ``original_outputs``.
        outputs = BeamDecoderOutput(
            logits=tf.zeros(
                [self.config.beam_width, self.config.vocab_size]),
            predicted_ids=tf.to_int64(beam_state.predicted_ids[:, time_ - 1]),
            log_probs=beam_state.log_probs,
            scores=beam_state.scores,
            beam_parent_ids=beam_state.beam_parent_ids,
            original_outputs=original_outputs.outputs)
        # Cell states are shuffled around by beam search
        next_cell_state = beam_search.nest_map(
            original_outputs.next_cell_state,
            lambda x: tf.gather(x, beam_state.beam_parent_ids))
    # The final step output
    step_output = DecoderStepOutput(outputs=outputs,
                                    next_cell_state=next_cell_state,
                                    next_loop_state=next_loop_state)
    return step_output