Exemplo n.º 1
0
    def __init__(self, inputs, sequence_length, time_major=False, name=None):
        """Set up the helper from the training inputs and their lengths.

        Args:
          inputs: A (structure of) input tensors.
          sequence_length: An int32 vector tensor.
          time_major: Python bool.  Whether the tensors in `inputs` are time
            major.  If `False` (default), they are assumed to be batch major
            and are passed through `_transpose_batch_time` first.
          name: Name scope for any created operations.

        Raises:
          ValueError: if `sequence_length` is not a 1D tensor.
        """
        with ops.name_scope(name, "TrainingHelper", [inputs, sequence_length]):
            inputs = ops.convert_to_tensor(inputs, name="inputs")
            if not time_major:
                # Everything below works on the transposed (presumably
                # time-major) layout.
                inputs = nest.map_structure(_transpose_batch_time, inputs)

            # Objects produced by `_unstack_ta`; `next_inputs` later calls
            # `.read(...)` on them (presumably TensorArrays -- confirm against
            # the helper's definition).
            self._input_tas = nest.map_structure(_unstack_ta, inputs)

            lengths = ops.convert_to_tensor(
                sequence_length, name="sequence_length")
            self._sequence_length = lengths
            lengths_shape = lengths.get_shape()
            if lengths_shape.ndims != 1:
                raise ValueError(
                    "Expected sequence_length to be a vector, but received shape: %s"
                    % lengths_shape)

            def _zeros_like_first_slice(inp):
                # Zero tensor shaped like the slice at index 0 of the
                # leading axis.
                return array_ops.zeros_like(inp[0, :])

            self._zero_inputs = nest.map_structure(
                _zeros_like_first_slice, inputs)

            self._batch_size = array_ops.size(sequence_length)
Exemplo n.º 2
0
 def zero_state(self, batch_size, dtype):
     """Build the initial state for the wrapped cells.

     Args:
       batch_size: Forwarded unchanged to each `zero_state` call.
       dtype: Forwarded unchanged to each `zero_state` call.

     Returns:
       When `self._state_is_tuple` is true, a tuple holding one
       `cell.zero_state(batch_size, dtype)` result per cell in
       `self._cells`; otherwise whatever the parent class's `zero_state`
       returns.
     """
     scope_name = type(self).__name__ + "ZeroState"
     with ops.name_scope(scope_name, values=[batch_size]):
         if not self._state_is_tuple:
             # We know here that state_size of each cell is not a tuple and
             # presumably does not contain TensorArrays or anything else
             # fancy, so the parent implementation can be used directly.
             return super(MultiRNNCell, self).zero_state(batch_size, dtype)

         per_cell_states = [
             cell.zero_state(batch_size, dtype) for cell in self._cells
         ]
         return tuple(per_cell_states)
Exemplo n.º 3
0
    def step(self, time, inputs, state, name=None):
        """Perform a decoding step.

        Before the cell runs, the incoming state is blended with the encoder
        output at this time step:

            new_state = state * (1 - gate) + encoder_output * gate

        where `gate` is `self.turn_points[:, time]`, applied per batch entry
        (one gate value for the whole state row of that entry).

        Args:
          time: scalar `int32` tensor.
          inputs: A (structure of) input tensors.
          state: Either a single 2-D state tensor or a flat tuple of 2-D state
            tensors.  For a tuple, the i-th element is blended with columns
            `[i * embs, (i + 1) * embs)` of the encoder output, where `embs`
            is the second dimension of `state[0]`.  NOTE(review): nested
            tuples (e.g. a tuple of LSTM state tuples) are not handled here,
            and a namedtuple state is passed to the cell as a plain tuple.
          name: Name scope for any created operations.

        Returns:
          `(outputs, next_state, next_inputs, finished)`, where `outputs` is a
          `BasicDecoderOutput(cell_outputs, sample_ids)` and the remaining
          three values come from `self._helper.next_inputs`.
        """
        with ops.name_scope(name, "BasicDecoderStep", (time, inputs, state)):
            # Keep the per-batch gate as a [batch, 1] column and let it
            # broadcast across the feature axis.  Tiling the [batch] vector
            # and reshaping it to [batch, embs] interleaves the gate values of
            # different batch entries inside a single row, so each entry would
            # be gated by its neighbours' values.
            # Assumes turn_points is [batch, time] and encoder_outputs is
            # [batch, time, depth], as the indexing below implies.
            gate = tf.expand_dims(self.turn_points[:, time], 1)
            keep = 1 - gate

            if isinstance(state, tuple):
                embs = tf.shape(state[0])[1]
                new_state = tuple(
                    tf.multiply(part, keep) + tf.multiply(
                        self.encoder_outputs[:, time,
                                             i * embs:(i + 1) * embs], gate)
                    for i, part in enumerate(state))
            else:
                new_state = tf.multiply(state, keep) + tf.multiply(
                    self.encoder_outputs[:, time, :], gate)

            cell_outputs, cell_state = self._cell(inputs, new_state)

            if self._output_layer is not None:
                cell_outputs = self._output_layer(cell_outputs)

            sample_ids = self._helper.sample(time=time,
                                             outputs=cell_outputs,
                                             state=cell_state)
            (finished, next_inputs,
             next_state) = self._helper.next_inputs(time=time,
                                                    outputs=cell_outputs,
                                                    state=cell_state,
                                                    sample_ids=sample_ids)
        outputs = BasicDecoderOutput(cell_outputs, sample_ids)
        return (outputs, next_state, next_inputs, finished)
Exemplo n.º 4
0
    def next_inputs(self, time, outputs, state, name=None, **unused_kwargs):
        """Compute the `finished` flags and the inputs for the following step.

        Args:
          time: Current step index; `time + 1` is compared against
            `self._sequence_length` to decide which entries are finished.
          outputs: Only used as a name-scope value.
          state: Returned unchanged.
          name: Name scope for any created operations.
          **unused_kwargs: Ignored.

        Returns:
          `(finished, next_inputs, state)`.  `next_inputs` is
          `self._zero_inputs` once every entry is finished; otherwise it is
          read from `self._input_tas` at index `time` (i.e. `next_time - 1`).
        """
        with ops.name_scope(name, "TrainingHelperNextInputs",
                            [time, outputs, state]):
            next_time = time + 1
            finished = (next_time >= self._sequence_length)
            all_finished = math_ops.reduce_all(finished)

            def _use_zero_inputs():
                return self._zero_inputs

            def _read_stored_inputs():
                # Index `next_time - 1` pairs with `initialize` handing out
                # zero inputs for the very first step.
                return nest.map_structure(
                    lambda stored: stored.read(next_time - 1),
                    self._input_tas)

            next_inputs = control_flow_ops.cond(
                all_finished, _use_zero_inputs, _read_stored_inputs)
            return (finished, next_inputs, state)
Exemplo n.º 5
0
 def sample(self, time, outputs, name=None, **unused_kwargs):
     """Pick the arg-max index along the last axis of `outputs`.

     Args:
       time: Only used as a name-scope value.
       outputs: Tensor whose last axis is reduced with `argmax`.
       name: Name scope for any created operations.
       **unused_kwargs: Ignored.

     Returns:
       The arg-max indices cast to `int32`.
     """
     with ops.name_scope(name, "TrainingHelperSample", [time, outputs]):
         best_index = math_ops.argmax(outputs, axis=-1)
         return math_ops.cast(best_index, dtypes.int32)
Exemplo n.º 6
0
 def initialize(self, name=None):
     """Return the initial `(finished, inputs)` pair.

     Args:
       name: Name scope for any created operations.

     Returns:
       A tuple `(finished, first_inputs)`: `finished` marks the entries of
       `self._sequence_length` that equal 0, and `first_inputs` is always
       `self._zero_inputs`.
     """
     with ops.name_scope(name, "TrainingHelperInitialize"):
         already_done = math_ops.equal(0, self._sequence_length)
         first_inputs = self._zero_inputs
         return (already_done, first_inputs)