def step(self, time, inputs, state, name=None):
    """Run the cell for one time step and ask the helper what comes next.

    The raw cell output is what gets emitted and what the helper receives as
    `outputs`; the output-layer projection (when an output layer is set) is
    used for sampling and is handed to the helper separately as
    `dist_outputs`.

    Args:
      time: scalar `int32` tensor.
      inputs: A (structure of) input tensors.
      state: A (structure of) state tensors and TensorArrays.
      name: Name scope for any created operations.

    Returns:
      `(outputs, next_state, next_inputs, finished)`.
    """
    scope_values = (time, inputs, state)
    with ops.name_scope(name, "BasicDecoderStep", scope_values):
        raw_outputs, new_cell_state = self._cell(inputs, state)

        # Only the sampling distribution goes through the output layer;
        # `raw_outputs` is deliberately left unprojected.
        if self.__output_layer:
            dist_outputs = self.__output_layer(raw_outputs)
        else:
            dist_outputs = raw_outputs

        sample_ids = self._helper.sample(
            time=time,
            outputs=dist_outputs,
            state=new_cell_state,
        )
        helper_result = self._helper.next_inputs(
            time=time,
            outputs=raw_outputs,
            dist_outputs=dist_outputs,
            state=new_cell_state,
            sample_ids=sample_ids,
        )
        finished, next_inputs, next_state = helper_result

    step_outputs = BasicDecoderOutput(raw_outputs, sample_ids)
    return (step_outputs, next_state, next_inputs, finished)
def output_dtype(self):
    """Return the output dtypes packed as a `BasicDecoderOutput`.

    The rnn-output entry mirrors the structure of `self._rnn_output_size()`,
    with every leaf replaced by the dtype of the first flattened component of
    the initial state (the cell's output dtype is assumed to match it). The
    sample-id entry is the helper's `sample_ids_dtype`.
    """
    first_state_component = nest.flatten(self._initial_state)[0]
    state_dtype = first_state_component.dtype

    def _to_state_dtype(_):
        # Every leaf of the output-size structure maps to the same dtype.
        return state_dtype

    rnn_output_dtypes = nest.map_structure(
        _to_state_dtype, self._rnn_output_size())
    return BasicDecoderOutput(
        rnn_output_dtypes, self._helper.sample_ids_dtype)
def step(self, time, inputs, state, name=None):
    """Run one decoding step with a readout stage after the cell.

    The cell output is passed through `self._readout` together with the
    current inputs and `state.attention`, then through the output layer when
    one is configured. The resulting tensor is used for sampling, for the
    helper's `next_inputs` call, and as the emitted rnn output.

    Args:
      time: scalar `int32` tensor.
      inputs: A (structure of) input tensors.
      state: decoder state; must expose an `attention` attribute.
      name: Name scope for any created operations.

    Returns:
      `(outputs, next_state, next_inputs, finished)`.
    """
    with ops.name_scope(name, "BasicDecoderStep", (time, inputs, state)):
        rnn_outputs, rnn_state = self._cell(inputs, state)

        # Readout consumes the attention carried by the *incoming* state.
        step_outputs = self._readout(inputs, rnn_outputs, state.attention)
        if self._output_layer is not None:
            step_outputs = self._output_layer(step_outputs)

        sample_ids = self._helper.sample(
            time=time, outputs=step_outputs, state=rnn_state)
        helper_result = self._helper.next_inputs(
            time=time,
            outputs=step_outputs,
            state=rnn_state,
            sample_ids=sample_ids,
        )
        finished, next_inputs, next_state = helper_result

    decoder_output = BasicDecoderOutput(step_outputs, sample_ids)
    return (decoder_output, next_state, next_inputs, finished)
def step(self, time, inputs, state, name=None):
    """Run one decoding step, optionally conditioning on a fixed context.

    When `self._context` is set it is concatenated to the step inputs along
    the last axis before the cell is called. The cell output is projected by
    the output layer when one is configured, then used for sampling and for
    the helper's `next_inputs` call.

    Args:
      time: scalar `int32` tensor.
      inputs: input tensor for this step.
      state: A (structure of) state tensors and TensorArrays.
      name: Name scope for any created operations.

    Returns:
      `(outputs, next_state, next_inputs, finished)`.
    """
    with tf.name_scope(name, "BasicDecoderStep", (time, inputs, state)):
        if self._context is None:
            cell_inputs = inputs
        else:
            cell_inputs = tf.concat([inputs, self._context], axis=-1)

        step_outputs, new_cell_state = self._cell(cell_inputs, state)
        if self._output_layer is not None:
            step_outputs = self._output_layer(step_outputs)

        sample_ids = self._helper.sample(
            time=time, outputs=step_outputs, state=new_cell_state)
        helper_result = self._helper.next_inputs(
            time=time,
            outputs=step_outputs,
            state=new_cell_state,
            sample_ids=sample_ids,
        )
        finished, next_inputs, next_state = helper_result

    decoder_output = BasicDecoderOutput(step_outputs, sample_ids)
    return (decoder_output, next_state, next_inputs, finished)
def step(self, time, inputs, state, name=None):
    """Run one grammar-constrained decoding step.

    `state` is a pair `(decoder_state, grammar_state)`. The cell output is
    projected by the output layer (when configured) and then constrained by
    `self._grammar_helper.constrain_logits`; the constrained values are what
    get sampled from, passed to the helper, and emitted. The grammar state
    advances on the tokens read from `self._fixed_outputs` at `time` when
    fixed outputs are provided, and on the sampled ids otherwise.

    Args:
      time: scalar `int32` tensor.
      inputs: A (structure of) input tensors.
      state: tuple `(decoder_state, grammar_state)`.
      name: Name scope for any created operations.

    Returns:
      `(outputs, next_state, next_inputs, finished)` where `next_state` is
      the pair `(next_decoder_state, next_grammar_state)`.
    """
    with tf.name_scope(name, "GrammarDecodingStep", (time, inputs, state)):
        decoder_state, grammar_state = state

        raw_outputs, cell_state = self._cell(inputs, decoder_state)
        if self._output_layer is not None:
            raw_outputs = self._output_layer(raw_outputs)

        constrained = self._grammar_helper.constrain_logits(
            raw_outputs, grammar_state)

        sample_ids = self._helper.sample(
            time=time, outputs=constrained, state=cell_state)
        helper_result = self._helper.next_inputs(
            time=time,
            outputs=constrained,
            state=cell_state,
            sample_ids=sample_ids,
        )
        finished, next_inputs, next_decoder_state = helper_result

        # Pick which tokens drive the grammar: the externally fixed ones if
        # present, the freshly sampled ones otherwise.
        if self._fixed_outputs is None:
            transition_tokens = sample_ids
        else:
            transition_tokens = self._fixed_outputs.read(time)
        next_grammar_state = self._grammar_helper.transition(
            grammar_state, transition_tokens, self.batch_size)

    next_state = (next_decoder_state, next_grammar_state)
    decoder_output = BasicDecoderOutput(constrained, sample_ids)
    return (decoder_output, next_state, next_inputs, finished)
def body(time, outputs_ta, state, inputs, finished, sequence_lengths):
    """Internal while_loop body that steps several decoders in lockstep.

    Every decoder in the enclosing-scope `decoders` is stepped once, its
    collected outputs are combined into one weighted result, and ids sampled
    from that combined result replace every decoder's next inputs.

    Relies on names from the enclosing scope: `decoders`,
    `maximum_iterations`, `impute_finished`, `decoders_zero_outputs`,
    `local2global`, `global2local`, `ma_policy` and `policy_mode`.

    Args:
      time: scalar int32 tensor.
      outputs_ta: list of (structures of) TensorArray. Entry `i` belongs to
        decoder `i`; the last two entries receive the combined output and
        the combination weights respectively.
      state: list, indexed per decoder, of (structures of) state tensors and
        TensorArrays.
      inputs: list, indexed per decoder, of (structures of) input tensors.
      finished: list, indexed per decoder, of bool tensors (keeping track of
        what's finished).
      sequence_lengths: list, indexed per decoder, of int32 tensors (keeping
        track of time of finish).

    Returns:
      `(time + 1, outputs_ta, next_state, next_inputs, next_finished,
        next_sequence_lengths)`, where every element after the first is a
      list; `outputs_ta` holds one entry per decoder plus the two extra
      entries described above.
    """
    # NOTE(review): `decoders_next_outputs` is filled below but never read.
    decoders_next_outputs = []
    decoders_next_states = []
    decoders_next_inputs = []
    decoders_next_finished = []
    decoders_next_seqlen = []
    decoders_next_ta = []
    outputs_collection = []
    decoder_cnt = 0
    for decoder in decoders:
        (next_outputs, decoder_state, next_inputs, decoder_finished) = decoder.step(
            time, inputs[decoder_cnt], state[decoder_cnt])
        # Once finished, stay finished.
        next_finished = math_ops.logical_or(decoder_finished, finished[decoder_cnt])
        if maximum_iterations is not None:
            next_finished = math_ops.logical_or(
                next_finished, time + 1 >= maximum_iterations)

        # The loop variables must keep the same nested structure between
        # iterations.
        nest.assert_same_structure(state[decoder_cnt], decoder_state)
        nest.assert_same_structure(outputs_ta[decoder_cnt], next_outputs)
        nest.assert_same_structure(inputs[decoder_cnt], next_inputs)

        # Zero out output values past finish
        # (keyed on `finished` from before this step, not `next_finished`).
        if impute_finished:
            emit = nest.map_structure(
                lambda out, zero: array_ops.where(finished[decoder_cnt], zero, out),
                next_outputs, decoders_zero_outputs[decoder_cnt])
        else:
            emit = next_outputs
        # The un-imputed `next_outputs` (not `emit`) feeds the combination
        # below. NOTE(review): `local2global` is defined outside this block;
        # presumably it maps decoder-local outputs into a shared space -
        # confirm.
        outputs_collection.append(local2global(next_outputs, decoder_cnt))

        # Copy through states past finish
        def _maybe_copy_state(new, cur):
            # TensorArrays and scalar states get passed through.
            if isinstance(cur, tensor_array_ops.TensorArray):
                pass_through = True
            else:
                new.set_shape(cur.shape)
                pass_through = (new.shape.ndims == 0)
            return new if pass_through else array_ops.where(
                finished[decoder_cnt], cur, new)

        next_state = None
        if impute_finished:
            next_state = nest.map_structure(_maybe_copy_state, decoder_state,
                                            state[decoder_cnt])
        else:
            next_state = decoder_state
        this_outputs_ta = nest.map_structure(
            lambda ta, out: ta.write(time, out), outputs_ta[decoder_cnt], emit)
        # Record `time + 1` as the length for entries that finish on exactly
        # this step; all other entries keep their previous length.
        next_sequence_lengths = array_ops.where(
            math_ops.logical_and(
                math_ops.logical_not(finished[decoder_cnt]), next_finished),
            array_ops.fill(array_ops.shape(sequence_lengths[decoder_cnt]), time + 1),
            sequence_lengths[decoder_cnt])

        decoders_next_inputs.append(next_inputs)
        decoders_next_outputs.append(next_outputs)
        decoders_next_states.append(next_state)
        decoders_next_finished.append(next_finished)
        decoders_next_seqlen.append(next_sequence_lengths)
        decoders_next_ta.append(this_outputs_ta)
        decoder_cnt += 1

    # Combination weights: softmax over the last axis of the first collected
    # output multiplied by `ma_policy`.
    ma_weights = tf.nn.softmax(tf.matmul(outputs_collection[0], ma_policy), -1)
    # NOTE(review): debug print; runs whenever this Python function is
    # traced and shows the tensor object, not its values.
    print('ma_weights_shape:', ma_weights)
    # In 'FULL' mode the first collected output only produces the weights
    # and is excluded from the weighted sum.
    if policy_mode == 'FULL':
        outputs_collection = outputs_collection[1:]
    # Stack the collected outputs and reverse the three axes so the
    # per-decoder axis comes last. NOTE(review): this assumes every
    # collected output is a rank-2 float tensor of identical shape
    # (presumably [batch, vocab]) - confirm.
    outputs_collection = tf.transpose(
        ops.convert_to_tensor(outputs_collection, dtype=dtypes.float32), [2, 1, 0])
    print('all_outputs_shape:', outputs_collection)
    # Weighted sum over the last (per-decoder) axis, then swap the two
    # remaining axes back.
    final_outputs = tf.transpose(
        tf.reduce_sum(outputs_collection * ma_weights, -1), [1, 0])
    # final_outputs=tf.transpose(outputs_collection, [2,1,0])[0]
    print('final_outputs_shape:', final_outputs)
    sample_ids = math_ops.cast(math_ops.argmax(final_outputs, axis=-1), dtypes.int32)
    wrapped_final_outputs = BasicDecoderOutput(final_outputs, sample_ids)
    # The last two TensorArray slots hold the combined output and the
    # weights.
    decoders_next_ta.append(
        nest.map_structure(lambda ta, out: ta.write(time, out),
                           outputs_ta[-2], wrapped_final_outputs))
    decoders_next_ta.append(
        nest.map_structure(lambda ta, out: ta.write(time, out),
                           outputs_ta[-1], ma_weights))
    # Replace every decoder's own next inputs with values derived from the
    # jointly chosen ids. NOTE(review): `global2local` is defined outside
    # this block - confirm what it returns.
    for dno in range(len(decoders)):
        decoders_next_inputs[dno] = global2local(sample_ids, dno)

    outputs_ta = decoders_next_ta
    next_inputs = decoders_next_inputs
    next_state = decoders_next_states
    next_finished = decoders_next_finished
    next_seqlen = decoders_next_seqlen
    return (time + 1, outputs_ta, next_state, next_inputs, next_finished, next_seqlen)
def output_size(self):
    """Return the output sizes packed as a `BasicDecoderOutput`.

    `rnn_output` is whatever `self._rnn_output_size()` reports and
    `sample_id` is the helper's `sample_ids_shape`.
    """
    rnn_size = self._rnn_output_size()
    sample_id_shape = self._helper.sample_ids_shape
    return BasicDecoderOutput(rnn_output=rnn_size, sample_id=sample_id_shape)
def output_size(self):
    """Return the output sizes packed as a `BasicDecoderOutput`.

    `rnn_output` is whatever `self._rnn_output_size()` reports; `sample_id`
    is a rank-0 (scalar) `TensorShape`.
    """
    rnn_size = self._rnn_output_size()
    scalar_shape = tensor_shape.TensorShape([])
    return BasicDecoderOutput(rnn_output=rnn_size, sample_id=scalar_shape)