示例#1
0
 def output_dtype(self):
     """ Describes the dtype of each field of the decoder output.
         The dtype of the first flattened component of the initial state is reused for the
         rnn output (and for the rnn state field when `self.extract_state` is set), while the
         sample id dtype is the one reported by the helper.
         :return: A `BasicDecoderWithStateOutput` when `self.extract_state` is truthy,
                  otherwise a `seq2seq.BasicDecoderOutput`.
     """
     first_component = nest.flatten(self._initial_state)[0]
     state_dtype = first_component.dtype

     # Plain decoder output: only the rnn output and the sample id are described
     if not self.extract_state:
         return seq2seq.BasicDecoderOutput(state_dtype,
                                           self._helper.sample_ids_dtype)

     # The extracted state shares the dtype of the rnn output
     return BasicDecoderWithStateOutput(state_dtype,
                                        state_dtype,
                                        self._helper.sample_ids_dtype)
示例#2
0
 def output_size(self):
     """ Describes the size of each field of the decoder output.
         :return: A `BasicDecoderWithStateOutput` (rnn output, rnn state, sample id) when
                  `self.extract_state` is truthy, otherwise a `seq2seq.BasicDecoderOutput`
                  (rnn output, sample id).
     """
     if not self.extract_state:
         return seq2seq.BasicDecoderOutput(
             rnn_output=self._rnn_output_size(),
             sample_id=self._helper.sample_ids_shape)

     rnn_output_size = self._rnn_output_size()
     # The extracted state is sized from the cell's output size
     rnn_state_size = tensor_shape.TensorShape([self._cell.output_size])
     sample_id_size = self._helper.sample_ids_shape
     return BasicDecoderWithStateOutput(rnn_output=rnn_output_size,
                                        rnn_state=rnn_state_size,
                                        sample_id=sample_id_size)
示例#3
0
    def step(self, time, inputs, state, name=None):
        """ Performs a decoding step, scoring the candidates supplied with the inputs
            :param time: scalar `int32` tensor.
            :param inputs: A `CandidateInputs` tuple (its `.inputs`, `.candidates` and `.candidates_emb`
                           fields are read here).
            :param state: A (structure of) state tensors and TensorArrays.
            :param name: Name scope for any created operations.
            :return: (outputs, next_state, next_inputs, finished)
        """
        assert isinstance(inputs, CandidateInputs), 'The inputs must be of type "CandidateInputs"'

        with ops.name_scope(name, "BasicDecoderStep", (time, inputs, state)):
            cell_inputs = inputs.inputs
            candidates = inputs.candidates
            candidates_emb = inputs.candidates_emb

            # The raw cell output is kept aside so it can be returned as the extracted state
            # (the original author notes it corresponds to cell_state.h, before the output layer)
            raw_cell_outputs, cell_state = self._cell(cell_inputs, state)
            projected = raw_cell_outputs
            if self._output_layer is not None:
                projected = self._output_layer(raw_cell_outputs)

            # Appending a constant 1. as an extra last column (bias term), then multiplying with the
            # candidate embeddings and summing over the last axis to get one logit per candidate
            with_bias = array_ops.pad(projected, [(0, 0), (0, 1)], constant_values=1.)
            weighted = with_bias[:, None, :] * candidates_emb
            candidate_logits = math_ops.reduce_sum(weighted, axis=-1)

            # Candidates whose id is not greater than 0 (PAD_IDs) are pushed down by LARGE_NEGATIVE
            is_real_candidate = gen_math_ops.greater(candidates, 0)
            output_mask = math_ops.cast(is_real_candidate, dtypes.float32)
            penalty = (1. - output_mask) * LARGE_NEGATIVE
            candidate_logits = gen_math_ops.add(candidate_logits, penalty)

            # The helper samples from (logits, candidates), then provides the next inputs
            sample_ids = self._helper.sample(time=time,
                                             outputs=(candidate_logits, candidates),
                                             state=cell_state)
            step_result = self._helper.next_inputs(time=time,
                                                   outputs=candidate_logits,
                                                   state=cell_state,
                                                   sample_ids=sample_ids)
            finished, next_inputs, next_state = step_result

        if not self.extract_state:
            return (seq2seq.BasicDecoderOutput(candidate_logits, sample_ids),
                    next_state, next_inputs, finished)
        return (BasicDecoderWithStateOutput(candidate_logits, raw_cell_outputs, sample_ids),
                next_state, next_inputs, finished)
示例#4
0
 def step(self, time, inputs, state, name=None):
     """ Performs a decoding step, optionally masking the outputs
         :param time: scalar `int32` tensor.
         :param inputs: Either a `MaskedInputs` tuple (its `.inputs` and `.mask` fields are read here)
                        or a plain tensor, in which case no mask is applied.
         :param state: A (structure of) state tensors and TensorArrays.
         :param name: Name scope for any created operations.
         :return: (outputs, next_state, next_inputs, finished)
     """
     assert isinstance(inputs, (MaskedInputs, ops.Tensor)), 'Expected "MaskedInputs" or a Tensor.'

     with ops.name_scope(name, "BasicDecoderStep", (time, inputs, state)):
         if isinstance(inputs, MaskedInputs):
             cell_inputs, output_mask = inputs.inputs, inputs.mask
         else:
             cell_inputs, output_mask = inputs, None

         # The raw cell output is kept aside so it can be returned as the extracted state
         # (the original author notes it corresponds to cell_state.h, before the output layer)
         raw_cell_outputs, cell_state = self._cell(cell_inputs, state)
         logits = (raw_cell_outputs if self._output_layer is None
                   else self._output_layer(raw_cell_outputs))

         # Adds (1 - mask) * LARGE_NEGATIVE, so positions where the mask is 1 are left unchanged
         if output_mask is not None:
             penalty = (1. - output_mask) * LARGE_NEGATIVE
             logits = gen_math_ops.add(logits, penalty)

         # The helper samples from the (masked) logits, then provides the next inputs
         sample_ids = self._helper.sample(time=time, outputs=logits, state=cell_state)
         step_result = self._helper.next_inputs(time=time,
                                                outputs=logits,
                                                state=cell_state,
                                                sample_ids=sample_ids)
         finished, next_inputs, next_state = step_result

     if not self.extract_state:
         return (seq2seq.BasicDecoderOutput(logits, sample_ids),
                 next_state, next_inputs, finished)
     return (BasicDecoderWithStateOutput(logits, raw_cell_outputs, sample_ids),
             next_state, next_inputs, finished)