def output_dtype(self):
    """ Returns the dtype structure of the outputs emitted by `step`

        The dtype of the first flattened component of the initial state is used for the rnn output
        (and for the rnn state when `extract_state` is set). The sample ids dtype is provided by the helper.

        :return: A `BasicDecoderWithStateOutput` of dtypes if `extract_state` is set,
                 otherwise a `seq2seq.BasicDecoderOutput` of dtypes.
    """
    # The cell dtype is assumed to match the first component of the initial state
    state_dtype = nest.flatten(self._initial_state)[0].dtype

    if not self.extract_state:
        return seq2seq.BasicDecoderOutput(state_dtype, self._helper.sample_ids_dtype)

    # The extracted state reuses the same dtype as the rnn output
    return BasicDecoderWithStateOutput(state_dtype, state_dtype, self._helper.sample_ids_dtype)
def output_size(self):
    """ Returns the size structure of the outputs emitted by `step`

        - rnn_output: whatever `self._rnn_output_size()` reports
        - rnn_state: (only with `extract_state`) a 1-dim `TensorShape` built from the cell's `output_size`
        - sample_id: the helper's `sample_ids_shape`

        :return: A `BasicDecoderWithStateOutput` of sizes if `extract_state` is set,
                 otherwise a `seq2seq.BasicDecoderOutput` of sizes.
    """
    if not self.extract_state:
        return seq2seq.BasicDecoderOutput(rnn_output=self._rnn_output_size(),
                                          sample_id=self._helper.sample_ids_shape)

    rnn_output_size = self._rnn_output_size()
    rnn_state_size = tensor_shape.TensorShape([self._cell.output_size])
    return BasicDecoderWithStateOutput(rnn_output=rnn_output_size,
                                       rnn_state=rnn_state_size,
                                       sample_id=self._helper.sample_ids_shape)
def step(self, time, inputs, state, name=None):
    """ Performs a decoding step, scoring the candidates that are supplied with the inputs

        :param time: scalar `int32` tensor.
        :param inputs: A `CandidateInputs` tuple (fields: `inputs`, `candidates`, `candidates_emb`).
        :param state: A (structure of) state tensors and TensorArrays.
        :param name: Name scope for any created operations.
        :return: (outputs, next_state, next_inputs, finished)
    """
    assert isinstance(inputs, CandidateInputs), 'The inputs must be of type "CandidateInputs"'

    with ops.name_scope(name, "BasicDecoderStep", (time, inputs, state)):
        cell_inputs = inputs.inputs
        candidates = inputs.candidates
        candidates_emb = inputs.candidates_emb

        # Raw cell output - Kept aside because it is what gets returned as the state when `extract_state` is set
        rnn_outputs, cell_state = self._cell(cell_inputs, state)
        projected = rnn_outputs if self._output_layer is None else self._output_layer(rnn_outputs)

        # Adding a bias dimension (a trailing column of ones on the last axis)
        with_bias = array_ops.pad(projected, [(0, 0), (0, 1)], constant_values=1.)

        # Candidate logits: sum over the last axis of (outputs * candidate embeddings)
        # NOTE(review): assumes candidates_emb is (batch, nb_candidates, output_dim + 1) - TODO confirm
        expanded = with_bias[:, None, :]
        candidate_logits = math_ops.reduce_sum(expanded * candidates_emb, axis=-1)

        # Masking PAD_IDs: candidates that are not > 0 receive LARGE_NEGATIVE on their logit
        is_selectable = gen_math_ops.greater(candidates, 0)
        output_mask = math_ops.cast(is_selectable, dtypes.float32)
        mask_penalty = (1. - output_mask) * LARGE_NEGATIVE
        masked_logits = gen_math_ops.add(candidate_logits, mask_penalty)

        # Sampling and computing next inputs - The helper's sample() receives the logits paired with the candidates
        sample_ids = self._helper.sample(time=time,
                                         outputs=(masked_logits, candidates),
                                         state=cell_state)
        finished, next_inputs, next_state = self._helper.next_inputs(time=time,
                                                                     outputs=masked_logits,
                                                                     state=cell_state,
                                                                     sample_ids=sample_ids)

    if not self.extract_state:
        return seq2seq.BasicDecoderOutput(masked_logits, sample_ids), next_state, next_inputs, finished
    return BasicDecoderWithStateOutput(masked_logits, rnn_outputs, sample_ids), next_state, next_inputs, finished
def step(self, time, inputs, state, name=None):
    """ Performs a decoding step, optionally masking the outputs

        :param time: scalar `int32` tensor.
        :param inputs: Either a `MaskedInputs` tuple (fields: `inputs`, `mask`) or a plain tensor (no masking).
        :param state: A (structure of) state tensors and TensorArrays.
        :param name: Name scope for any created operations.
        :return: (outputs, next_state, next_inputs, finished)
    """
    assert isinstance(inputs, (MaskedInputs, ops.Tensor)), 'Expected "MaskedInputs" or a Tensor.'

    with ops.name_scope(name, "BasicDecoderStep", (time, inputs, state)):
        # A plain tensor carries no mask, a MaskedInputs tuple carries its own
        if isinstance(inputs, MaskedInputs):
            cell_inputs, output_mask = inputs.inputs, inputs.mask
        else:
            cell_inputs, output_mask = inputs, None

        # Raw cell output - Kept aside because it is what gets returned as the state when `extract_state` is set
        rnn_outputs, cell_state = self._cell(cell_inputs, state)
        logits = rnn_outputs if self._output_layer is None else self._output_layer(rnn_outputs)

        # Positions where the mask is 0 receive LARGE_NEGATIVE, positions where it is 1 are left untouched
        if output_mask is not None:
            mask_penalty = (1. - output_mask) * LARGE_NEGATIVE
            logits = gen_math_ops.add(logits, mask_penalty)

        # Sampling and computing next inputs
        sample_ids = self._helper.sample(time=time, outputs=logits, state=cell_state)
        finished, next_inputs, next_state = self._helper.next_inputs(time=time,
                                                                     outputs=logits,
                                                                     state=cell_state,
                                                                     sample_ids=sample_ids)

    if not self.extract_state:
        return seq2seq.BasicDecoderOutput(logits, sample_ids), next_state, next_inputs, finished
    return BasicDecoderWithStateOutput(logits, rnn_outputs, sample_ids), next_state, next_inputs, finished