Example #1
0
 def step(self, time, inputs, state, name=None):
     """Perform a single decoding step.

     The cell is run on ``inputs``/``state``. When an output layer is set
     (truthy), it projects the raw cell outputs into ``dist_outputs``, which
     the helper samples from; otherwise the raw outputs are used directly.
     The helper then receives both the raw and the projected outputs to
     produce the finished flags, next inputs and next state.

     Args:
       time: scalar `int32` tensor.
       inputs: A (structure of) input tensors.
       state: A (structure of) state tensors and TensorArrays.
       name: Name scope for any created operations.

     Returns:
       `(outputs, next_state, next_inputs, finished)`, where `outputs` is a
       `BasicDecoderOutput` holding the raw (unprojected) cell outputs and
       the sampled ids.
     """
     with ops.name_scope(name, "BasicDecoderStep", (time, inputs, state)):
         raw_outputs, new_cell_state = self._cell(inputs, state)

         # The projected outputs feed the sampler; the raw outputs are emitted.
         if self.__output_layer:
             dist_outputs = self.__output_layer(raw_outputs)
         else:
             dist_outputs = raw_outputs

         sample_ids = self._helper.sample(
             time=time, outputs=dist_outputs, state=new_cell_state)
         finished, next_inputs, next_state = self._helper.next_inputs(
             time=time,
             outputs=raw_outputs,
             dist_outputs=dist_outputs,
             state=new_cell_state,
             sample_ids=sample_ids)
     step_outputs = BasicDecoderOutput(raw_outputs, sample_ids)
     return (step_outputs, next_state, next_inputs, finished)
Example #2
0
 def output_dtype(self):
     """Return the dtype structure of this decoder's per-step outputs.

     The cell-output dtype is assumed to match the dtype of the first
     flattened component of ``self._initial_state``. That dtype is replicated
     over the ``self._rnn_output_size()`` structure and paired with the
     helper's ``sample_ids_dtype``.
     """
     first_state_component = nest.flatten(self._initial_state)[0]
     state_dtype = first_state_component.dtype
     rnn_output_dtypes = nest.map_structure(
         lambda _unused: state_dtype, self._rnn_output_size())
     return BasicDecoderOutput(rnn_output_dtypes,
                               self._helper.sample_ids_dtype)
Example #3
0
    def step(self, time, inputs, state, name=None):
        """Perform one decoding step with a readout stage before sampling.

        Args:
          time: scalar `int32` tensor.
          inputs: (structure of) input tensors for this step.
          state: decoder state from the previous step; its ``attention``
            field is passed to the readout.
          name: optional name scope for created operations.

        Returns:
          `(outputs, next_state, next_inputs, finished)`.
        """
        with ops.name_scope(name, "BasicDecoderStep", (time, inputs, state)):
            rnn_out, new_cell_state = self._cell(inputs, state)

            # NOTE(review): the readout uses the *incoming* state's attention,
            # not new_cell_state's — confirm this is intended.
            step_outputs = self._readout(inputs, rnn_out, state.attention)

            projection = self._output_layer
            if projection is not None:
                step_outputs = projection(step_outputs)

            sample_ids = self._helper.sample(
                time=time, outputs=step_outputs, state=new_cell_state)
            finished, next_inputs, next_state = self._helper.next_inputs(
                time=time,
                outputs=step_outputs,
                state=new_cell_state,
                sample_ids=sample_ids)
        decoder_output = BasicDecoderOutput(step_outputs, sample_ids)
        return decoder_output, next_state, next_inputs, finished
Example #4
0
    def step(self, time, inputs, state, name=None):
        """Perform one decoding step, optionally appending a context tensor.

        When ``self._context`` is set, it is concatenated to ``inputs`` along
        the last axis before the cell is applied.

        Args:
          time: scalar `int32` tensor.
          inputs: input tensor(s) for this step.
          state: (structure of) state tensors and TensorArrays.
          name: optional name scope for created operations.

        Returns:
          `(outputs, next_state, next_inputs, finished)`.
        """
        with tf.name_scope(name, "BasicDecoderStep", (time, inputs, state)):
            cell_inputs = inputs
            if self._context is not None:
                cell_inputs = tf.concat([inputs, self._context], axis=-1)

            step_outputs, new_cell_state = self._cell(cell_inputs, state)
            if self._output_layer is not None:
                step_outputs = self._output_layer(step_outputs)

            sample_ids = self._helper.sample(
                time=time, outputs=step_outputs, state=new_cell_state)
            finished, next_inputs, next_state = self._helper.next_inputs(
                time=time,
                outputs=step_outputs,
                state=new_cell_state,
                sample_ids=sample_ids)

        decoder_output = BasicDecoderOutput(step_outputs, sample_ids)
        return (decoder_output, next_state, next_inputs, finished)
Example #5
0
 def step(self, time, inputs, state, name=None):
     """Perform one decoding step with grammar-constrained outputs.

     ``state`` is unpacked as a ``(decoder_state, grammar_state)`` pair. The
     (optionally projected) cell outputs are passed through
     ``self._grammar_helper.constrain_logits`` before sampling. The grammar
     state is then advanced with ``self._fixed_outputs.read(time)`` when
     fixed outputs are configured, otherwise with the sampled ids.

     Args:
       time: scalar `int32` tensor.
       inputs: (structure of) input tensors for this step.
       state: ``(decoder_state, grammar_state)`` tuple.
       name: optional name scope for created operations.

     Returns:
       `(outputs, next_state, next_inputs, finished)`, where ``next_state``
       is ``(next_decoder_state, next_grammar_state)``.
     """
     with tf.name_scope(name, "GrammarDecodingStep", (time, inputs, state)):
         decoder_state, grammar_state = state
         step_outputs, new_cell_state = self._cell(inputs, decoder_state)
         if self._output_layer is not None:
             step_outputs = self._output_layer(step_outputs)

         # Constrain the outputs by the current grammar state before sampling.
         step_outputs = self._grammar_helper.constrain_logits(
             step_outputs, grammar_state)

         sample_ids = self._helper.sample(
             time=time, outputs=step_outputs, state=new_cell_state)
         finished, next_inputs, next_decoder_state = self._helper.next_inputs(
             time=time,
             outputs=step_outputs,
             state=new_cell_state,
             sample_ids=sample_ids)

         # Advance the grammar with the fixed outputs if given, else the samples.
         if self._fixed_outputs is None:
             transition_ids = sample_ids
         else:
             transition_ids = self._fixed_outputs.read(time)
         next_grammar_state = self._grammar_helper.transition(
             grammar_state, transition_ids, self.batch_size)

         next_state = (next_decoder_state, next_grammar_state)
     decoder_output = BasicDecoderOutput(step_outputs, sample_ids)
     return (decoder_output, next_state, next_inputs, finished)
Example #6
0
    def body(time, outputs_ta, state, inputs, finished, sequence_lengths):
        """Internal while_loop body that steps every decoder and mixes them.

        Each decoder in ``decoders`` is stepped on its own slice of
        ``inputs``/``state``. Its outputs are mapped with ``local2global``,
        and the collection is combined with softmax weights computed from the
        first collected output and ``ma_policy``. The argmax of the combined
        outputs is the shared sample id. ``global2local`` then turns that id
        into every decoder's next input, replacing the inputs each decoder
        proposed itself.

        Args:
          time: scalar int32 tensor.
          outputs_ta: list of TensorArray structures: one per decoder, then
            one for the combined `BasicDecoderOutput` (index -2) and one for
            the mixing weights (index -1).
          state: list of per-decoder (structures of) state tensors and
            TensorArrays.
          inputs: list of per-decoder (structures of) input tensors.
          finished: list of per-decoder bool tensors (what has finished).
          sequence_lengths: list of per-decoder int32 tensors (time of finish).

        Returns:
          `(time + 1, outputs_ta, next_state, next_inputs, next_finished,
            next_sequence_lengths)`, with the same list layout as the
            arguments.
        """
        def _maybe_copy_state(done, new, cur):
            # Keep `cur` where `done` is set. TensorArrays and scalar states
            # are passed through unchanged.
            if isinstance(cur, tensor_array_ops.TensorArray):
                pass_through = True
            else:
                new.set_shape(cur.shape)
                pass_through = (new.shape.ndims == 0)
            return new if pass_through else array_ops.where(done, cur, new)

        def _write_step(ta, out):
            # Write this step's value into its TensorArray at index `time`.
            return ta.write(time, out)

        decoders_next_states = []
        decoders_next_finished = []
        decoders_next_seqlen = []
        decoders_next_ta = []
        outputs_collection = []

        for decoder_cnt, decoder in enumerate(decoders):
            prev_finished = finished[decoder_cnt]
            prev_state = state[decoder_cnt]
            (next_outputs, decoder_state, next_inputs,
             decoder_finished) = decoder.step(time, inputs[decoder_cnt],
                                              prev_state)
            next_finished = math_ops.logical_or(decoder_finished,
                                                prev_finished)
            if maximum_iterations is not None:
                next_finished = math_ops.logical_or(
                    next_finished, time + 1 >= maximum_iterations)

            nest.assert_same_structure(prev_state, decoder_state)
            nest.assert_same_structure(outputs_ta[decoder_cnt], next_outputs)
            # The decoder's own next_inputs are only structure-checked; the
            # inputs actually fed back are rebuilt from the mixed sample ids.
            nest.assert_same_structure(inputs[decoder_cnt], next_inputs)

            # Zero out output values past finish.
            if impute_finished:
                emit = nest.map_structure(
                    lambda out, zero: array_ops.where(prev_finished, zero, out),
                    next_outputs, decoders_zero_outputs[decoder_cnt])
            else:
                emit = next_outputs

            # NOTE(review): the mixture uses the raw (non-imputed) outputs,
            # not `emit` — confirm this is intended.
            outputs_collection.append(local2global(next_outputs, decoder_cnt))

            # Copy through states past finish.
            if impute_finished:
                next_state = nest.map_structure(
                    lambda new, cur: _maybe_copy_state(prev_finished, new, cur),
                    decoder_state, prev_state)
            else:
                next_state = decoder_state

            this_outputs_ta = nest.map_structure(
                _write_step, outputs_ta[decoder_cnt], emit)

            # Entries finishing on this step get length time + 1; all others
            # keep their previous length.
            prev_lengths = sequence_lengths[decoder_cnt]
            next_sequence_lengths = array_ops.where(
                math_ops.logical_and(math_ops.logical_not(prev_finished),
                                     next_finished),
                array_ops.fill(array_ops.shape(prev_lengths), time + 1),
                prev_lengths)

            decoders_next_states.append(next_state)
            decoders_next_finished.append(next_finished)
            decoders_next_seqlen.append(next_sequence_lengths)
            decoders_next_ta.append(this_outputs_ta)

        # Mixing weights: softmax (last axis) of the first collected output
        # multiplied by `ma_policy`.
        ma_weights = tf.nn.softmax(tf.matmul(outputs_collection[0], ma_policy),
                                   -1)
        if policy_mode == 'FULL':
            # Exclude the first entry (used above for the weights) from the mix.
            outputs_collection = outputs_collection[1:]
        # Stacking puts the decoder index on axis 0. The [2, 1, 0] transpose
        # (which requires each collected output to be rank 2) moves it to the
        # last axis, so the reduce_sum below sums over decoders.
        stacked_outputs = tf.transpose(
            ops.convert_to_tensor(outputs_collection, dtype=dtypes.float32),
            [2, 1, 0])
        final_outputs = tf.transpose(
            tf.reduce_sum(stacked_outputs * ma_weights, -1), [1, 0])
        sample_ids = math_ops.cast(math_ops.argmax(final_outputs, axis=-1),
                                   dtypes.int32)

        wrapped_final_outputs = BasicDecoderOutput(final_outputs, sample_ids)
        decoders_next_ta.append(
            nest.map_structure(_write_step, outputs_ta[-2],
                               wrapped_final_outputs))
        decoders_next_ta.append(
            nest.map_structure(_write_step, outputs_ta[-1], ma_weights))

        # Every decoder is fed the jointly chosen ids as its next input.
        next_inputs = [global2local(sample_ids, dno)
                       for dno in range(len(decoders))]

        return (time + 1, decoders_next_ta, decoders_next_states, next_inputs,
                decoders_next_finished, decoders_next_seqlen)
Example #7
0
 def output_size(self):
     """Return the per-step output sizes as a `BasicDecoderOutput`.

     ``rnn_output`` comes from ``self._rnn_output_size()``; ``sample_id``
     uses the helper's ``sample_ids_shape``.
     """
     rnn_size = self._rnn_output_size()
     ids_shape = self._helper.sample_ids_shape
     return BasicDecoderOutput(rnn_output=rnn_size, sample_id=ids_shape)
Example #8
0
 def output_size(self):
     """Return the per-step output sizes as a `BasicDecoderOutput`.

     ``rnn_output`` comes from ``self._rnn_output_size()``; each sample id
     has an empty (scalar) `TensorShape`.
     """
     rnn_size = self._rnn_output_size()
     scalar_shape = tensor_shape.TensorShape([])
     return BasicDecoderOutput(rnn_output=rnn_size, sample_id=scalar_shape)