def body(time, outputs_ta, parents):
    """Run one backtracking step of the beam-search finalization loop.

    Relies on names from the enclosing scope: `final_outputs`, `self`,
    `nest`, `gather_helper` and `DecoderOutput`.

    Args:
        time: index of the step to process along axis 0 of `final_outputs`.
        outputs_ta: structure of TensorArrays receiving the re-ordered
            logits and ids for each step.
        parents: indices used to gather this step's entries.

    Returns:
        Tuple `(time + 1, outputs_ta, parents)` where `parents` are the
        parent indices gathered at this step, to be used for the next one.
    """
    # get ids, logits and parents predicted at time step by decoder
    input_t = nest.map_structure(lambda t: t[time], final_outputs)

    # extract the entries corresponding to parents
    # NOTE(review): gather_helper is defined elsewhere; presumably it picks,
    # for every batch entry, the beam entries indexed by `parents` - confirm.
    new_state = nest.map_structure(
        lambda t: gather_helper(t, parents, self._batch_size,
                                self._beam_size),
        input_t)

    # create new output
    new_output = DecoderOutput(logits=new_state.logits, ids=new_state.ids)

    # write beam ids
    outputs_ta = nest.map_structure(
        lambda ta, out: ta.write(time, out), outputs_ta, new_output)

    # Carry the gathered parent pointers forward. Returning the incoming
    # `parents` unchanged would freeze the indices at their initial value,
    # so the beams would never be traced back through their ancestors.
    # NOTE(review): assumes `final_outputs` has a `parents` field whose dtype
    # and shape match the loop's initial parents - confirm against caller.
    return (time + 1), outputs_ta, new_state.parents
def final_output_dtype(self):
    """Return the dtypes of the finalized output, for the finalize method.

    The result is a DecoderOutput whose `logits` entry is the wrapped cell's
    output dtype and whose `ids` entry is `tf.int32`.

    NOTE(review): `finalize` reads this as an attribute
    (`self.final_output_dtype`), so it is presumably exposed as a property;
    the decorator is not visible in this excerpt.
    """
    logits_dtype = self._cell.output_dtype
    ids_dtype = tf.int32
    return DecoderOutput(logits=logits_dtype, ids=ids_dtype)
def finalize(self, final_outputs, final_state):
    """Re-order the decoded beams by following parent pointers backwards.

    The outputs are walked from the last time step to the first. At every
    step the entries are gathered with the current parent indices, and the
    parent indices gathered at that step become the indices used for the
    step before it. The result is flipped back to forward time order.

    Args:
        final_outputs: structure of tensors of shape
            [time dimension, batch_size, beam_size, d]. Its `ids`, `logits`
            and `parents` fields are read.
        final_state: not read by this method.

    Returns:
        [time, batch, beam, ...] structure of Tensor, as a DecoderOutput
        holding `logits` and `ids`.

    """
    # reverse the time dimension
    maximum_iterations = tf.shape(final_outputs.ids)[0]
    final_outputs = nest.map_structure(lambda t: tf.reverse(t, axis=[0]),
                                       final_outputs)

    # initial states
    def create_ta(d):
        return tf.TensorArray(dtype=d, size=maximum_iterations)

    initial_time = tf.constant(0, dtype=tf.int32)
    initial_outputs_ta = nest.map_structure(create_ta,
                                            self.final_output_dtype)
    # every beam starts by pointing at itself: [[0, 1, ..., beam_size - 1]]
    # repeated batch_size times
    initial_parents = tf.tile(
        tf.expand_dims(tf.range(self._beam_size), axis=0),
        multiples=[self._batch_size, 1])

    def condition(time, outputs_ta, parents):
        return tf.less(time, maximum_iterations)

    # beam search decoding cell
    def body(time, outputs_ta, parents):
        # get ids, logits and parents predicted at time step by decoder
        input_t = nest.map_structure(lambda t: t[time], final_outputs)

        # extract the entries corresponding to parents
        # NOTE(review): gather_helper is defined elsewhere; presumably it
        # picks, for every batch entry, the beam entries indexed by
        # `parents` - confirm.
        new_state = nest.map_structure(
            lambda t: gather_helper(t, parents, self._batch_size,
                                    self._beam_size),
            input_t)

        # create new output
        new_output = DecoderOutput(logits=new_state.logits,
                                   ids=new_state.ids)

        # write beam ids
        outputs_ta = nest.map_structure(
            lambda ta, out: ta.write(time, out), outputs_ta, new_output)

        # Carry the gathered parent pointers forward. Passing the incoming
        # `parents` through unchanged would keep them at the initial
        # identity indices for the whole loop, so no beam would ever be
        # traced back through its ancestors.
        # NOTE(review): assumes `final_outputs.parents` has the same dtype
        # (int32) and shape as `initial_parents`, as tf.while_loop requires
        # for loop variables - confirm.
        return (time + 1), outputs_ta, new_state.parents

    res = tf.while_loop(
        condition,
        body,
        loop_vars=[initial_time, initial_outputs_ta, initial_parents],
        back_prop=False)

    # unfold and stack the structure from the nested tas
    final_outputs = nest.map_structure(lambda ta: ta.stack(), res[1])

    # reverse time step
    final_outputs = nest.map_structure(lambda t: tf.reverse(t, axis=[0]),
                                       final_outputs)

    return DecoderOutput(logits=final_outputs.logits, ids=final_outputs.ids)
def final_output_size(self):
    """Return the shapes of the finalized output as a DecoderOutput.

    `logits` has shape [beam_size, vocab_size] and `ids` has shape
    [beam_size], both expressed as `tf.TensorShape`.
    """
    beam_size = self._beam_size
    logits_shape = tf.TensorShape([beam_size, self._vocab_size])
    ids_shape = tf.TensorShape([beam_size])
    return DecoderOutput(logits=logits_shape, ids=ids_shape)