コード例 #1
0
        def body(time, outputs_ta, parents):
            """Run one backtracking step over the time-reversed decoder outputs.

            Args:
                time: loop counter, used as the index into the (reversed)
                    time axis of ``final_outputs``.
                outputs_ta: structure of TensorArrays receiving the outputs.
                parents: indices of the beams being followed at this step
                    (presumably shaped [batch_size, beam_size] - confirm
                    against gather_helper).

            Returns:
                (time + 1, updated outputs_ta, parents for the next step).
            """
            # ids, logits and parents emitted by the decoder at this step
            input_t = nest.map_structure(lambda t: t[time], final_outputs)

            # extract the entries of the beams currently being followed
            new_state = nest.map_structure(
                lambda t: gather_helper(t, parents, self._batch_size,
                                        self._beam_size), input_t)

            # only logits and ids are part of the final output
            new_output = DecoderOutput(logits=new_state.logits,
                                       ids=new_state.ids)

            # write this step's ids/logits
            outputs_ta = nest.map_structure(
                lambda ta, out: ta.write(time, out), outputs_ta, new_output)

            # Follow the parent pointers of the selected beams to the
            # previous step. Returning the unchanged `parents` would keep the
            # initial index order forever, so the gather above would never
            # actually backtrack through the beam tree.
            return (time + 1), outputs_ta, new_state.parents
コード例 #2
0
 def final_output_dtype(self):
     """Dtypes of the structure produced by the finalize method.

     Returns:
         DecoderOutput whose ``logits`` field is the wrapped cell's output
         dtype and whose ``ids`` field is ``tf.int32``.
     """
     cell_dtype = self._cell.output_dtype
     return DecoderOutput(ids=tf.int32, logits=cell_dtype)
コード例 #3
0
    def finalize(self, final_outputs, final_state):
        """Backtrack beam-search outputs so each beam holds one coherent path.

        Walks the decoder outputs from the last time step to the first.
        It starts with the identity beam order (beam i is beam i). At each
        step it gathers the entries of the beams currently being followed,
        then moves to those beams' parents for the previous step.

        Args:
            final_outputs: structure of tensors of shape
                    [time dimension, batch_size, beam_size, d]; must expose
                    ``logits``, ``ids`` and ``parents`` fields. ``parents``
                    is assumed to hold, per beam, the index of the beam it
                    extended at the previous step (confirm with the step fn).
            final_state: instance of BeamSearchDecoderOutput (not used by
                    this method).

        Returns:
            DecoderOutput(logits, ids), a [time, batch, beam, ...] structure
            of Tensor in original (forward) time order.

        """
        # Backtracking starts at the last step, so reverse the time dimension.
        maximum_iterations = tf.shape(final_outputs.ids)[0]
        final_outputs = nest.map_structure(lambda t: tf.reverse(t, axis=[0]),
                                           final_outputs)

        # One TensorArray per output field, sized to the number of steps.
        def create_ta(d):
            return tf.TensorArray(dtype=d, size=maximum_iterations)

        initial_time = tf.constant(0, dtype=tf.int32)
        initial_outputs_ta = nest.map_structure(create_ta,
                                                self.final_output_dtype)
        # [batch_size, beam_size] identity indices: at the last step, beam i
        # is followed from beam i.
        initial_parents = tf.tile(tf.expand_dims(tf.range(self._beam_size),
                                                 axis=0),
                                  multiples=[self._batch_size, 1])

        def condition(time, outputs_ta, parents):
            return tf.less(time, maximum_iterations)

        # one backtracking step
        def body(time, outputs_ta, parents):
            # ids, logits and parents emitted by the decoder at this step
            input_t = nest.map_structure(lambda t: t[time], final_outputs)

            # extract the entries of the beams currently being followed
            new_state = nest.map_structure(
                lambda t: gather_helper(t, parents, self._batch_size,
                                        self._beam_size), input_t)

            # only logits and ids are part of the final output
            new_output = DecoderOutput(logits=new_state.logits,
                                       ids=new_state.ids)

            # write this step's ids/logits
            outputs_ta = nest.map_structure(
                lambda ta, out: ta.write(time, out), outputs_ta, new_output)

            # Follow the parent pointers of the selected beams to the
            # previous step. Returning the unchanged `parents` would keep the
            # identity order forever, so the gather would never backtrack.
            return (time + 1), outputs_ta, new_state.parents

        _, outputs_ta, _ = tf.while_loop(
            condition,
            body,
            loop_vars=[initial_time, initial_outputs_ta, initial_parents],
            back_prop=False)

        # unfold and stack the structure from the nested TensorArrays
        final_outputs = nest.map_structure(lambda ta: ta.stack(), outputs_ta)

        # restore forward time order
        final_outputs = nest.map_structure(lambda t: tf.reverse(t, axis=[0]),
                                           final_outputs)

        return DecoderOutput(logits=final_outputs.logits,
                             ids=final_outputs.ids)
コード例 #4
0
 def final_output_size(self):
     """Static shapes of the output fields, without time and batch dims.

     Returns:
         DecoderOutput with ``logits`` shaped [beam_size, vocab_size] and
         ``ids`` shaped [beam_size].
     """
     beam = self._beam_size
     logits_shape = tf.TensorShape([beam, self._vocab_size])
     ids_shape = tf.TensorShape([beam])
     return DecoderOutput(logits=logits_shape, ids=ids_shape)