Example #1
def stack_ta(ta):
    # Stack the per-step TensorArray into a single tensor; if the layer is not
    # time-major, transpose the result back to batch-major order.
    t = ta.stack()
    if not self._time_major:
        t = common_utils.transpose_batch_time(t)
    return t
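For context, `stack_ta` converts a `tf.TensorArray` that was written one time step at a time into a single tensor, transposing back to batch-major when the layer is not time-major. Below is a minimal sketch of that pattern, assuming `common_utils.transpose_batch_time` simply swaps the first two dimensions:

import tensorflow as tf

# A TensorArray holding 3 time steps of a [batch=4, features=2] tensor.
ta = tf.TensorArray(tf.float32, size=3)
for t in range(3):
    ta = ta.write(t, tf.fill([4, 2], float(t)))

time_major = ta.stack()                            # shape [3, 4, 2]
batch_major = tf.transpose(time_major, [1, 0, 2])  # shape [4, 3, 2]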
    def call(self,
             inputs,
             initial_state=None,
             reset_mask=None,
             training=False):
        """Perform the computation.

        Args:
          inputs: A tuple containing tensors in batch-major format,
            each shaped `[batch_size, n, ...]`.

            If none of the inputs has rank greater than 2 (i.e., all inputs
            are shaped `[batch_size, d]` or `[batch_size]`), then it is assumed
            that a single frame is being calculated and that no time dimension
            was provided.  In this case, a single step is taken and the outputs
            will not have a singleton time dimension either.
          initial_state: (Optional) An initial state for `cell`.  If not provided,
            `dtype` must be set and `cell.get_initial_state()` is used instead.
          reset_mask: (Optional) A `bool` matrix shaped `[batch_size, n]`,
            describing the locations at which the state will be reset to zeros.
            Typically this is the value `time_steps.is_first()` where `time_steps`
            is a `TimeStep` containing tensors of the shape `[batch_size, n, ...]`.
            The `zero_state` of the cell will be used whenever `reset` is `True`,
            instead of either the current state or the `initial_state`.

            If this argument is not provided, state resetting is not performed
            (this tends to speed up the computation by a non-negligible amount).
          training: Whether the output is being used for training.

        Returns:
          A 2-tuple `(outputs, final_state)` where:

           - `outputs` contains the outputs for all states of the unroll; this is
             either a tensor or nested tuple with tensors all shaped
             `[batch_size, n, ...]` (if at least one input had rank `3` or above)
             or `[batch_size, ...]` (if all of the inputs were at most rank `2`),
             with structure and shape matching `cell.output_size`.
           - `final_state` contains the final state of the unroll, with structure
             and shape matching `cell.state_size`.

        Raises:
          ValueError: if static batch sizes within input tensors don't match.
          ValueError: if `initial_state` is `None` and `self.dtype` is `None`.
        """
        initial_state_missing = not common.safe_has_state(initial_state)

        if initial_state_missing and self.dtype is None:
            raise ValueError("Must provide either dtype or initial_state")

        inputs_flat = [
            tf.convert_to_tensor(x, name="input")
            for x in tf.nest.flatten(inputs)
        ]
        has_time_axis = all(
            x.shape.ndims is None or x.shape.ndims > 2 for x in inputs_flat)

        if not has_time_axis:
            # No time axis; since we're converting to time-major anyway, add a
            # singleton time axis at the front.
            inputs_flat = [tf.expand_dims(x, 0) for x in inputs_flat]
        else:
            # Assume all inputs are batch major.  Convert to time major.
            inputs_flat = [common.transpose_batch_time(x) for x in inputs_flat]

        inputs_static_shapes = tuple(x.shape for x in inputs_flat)
        batch_size = _best_effort_input_batch_size(inputs_flat)
        # inputs_flat is now time-major, so dimension 1 holds the static batch
        # size (if known).
        const_batch_size = tensor_shape.dimension_value(
            inputs_static_shapes[0][1])

        inputs = tf.nest.pack_sequence_as(inputs, inputs_flat)

        # reset_mask is batch major.  Convert to time major.
        if reset_mask is not None:
            reset_mask = tf.transpose(a=reset_mask)

        for shape in inputs_static_shapes:
            got_batch_size = tensor_shape.dimension_value(shape[1])
            if const_batch_size is None:
                const_batch_size = got_batch_size
            if got_batch_size is not None and const_batch_size != got_batch_size:
                raise ValueError(
                    "batch_size is not the same for all the elements in the input. "
                    "Saw values %s and %s" %
                    (const_batch_size, got_batch_size))

        if initial_state_missing:
            dtype = self.dtype
            initial_state = zero_state = self.cell.get_initial_state(
                batch_size=batch_size, dtype=self.dtype)
        else:
            dtype = _infer_state_dtype(self.dtype, initial_state)
            zero_state = self.cell.get_initial_state(batch_size=batch_size,
                                                     dtype=dtype)

        # Try to get the iteration count statically; if that's not possible,
        # access it dynamically at runtime.
        iterations = tensor_shape.dimension_value(inputs_flat[0].shape[0])
        iterations = iterations or tf.shape(input=inputs_flat[0])[0]

        if not tf.is_tensor(iterations) and iterations == 1:
            # Take exactly one time step
            outputs, new_state = _static_unroll_single_step(
                self.cell,
                inputs,
                reset_mask,
                state=initial_state,
                zero_state=zero_state,
                training=training)
        else:
            outputs, new_state = _dynamic_unroll_multi_step(
                self.cell,
                inputs,
                reset_mask,
                initial_state=initial_state,
                zero_state=zero_state,
                dtype=dtype,
                parallel_iterations=self.parallel_iterations,
                swap_memory=self.swap_memory,
                iterations=iterations,
                const_batch_size=const_batch_size,
                training=training)

        if not has_time_axis:
            # Remove the time axis.
            outputs = tf.nest.map_structure(lambda o: tf.squeeze(o, axis=1),
                                            outputs)

        return outputs, new_state
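The listing appears to come from a `DynamicUnroll`-style Keras layer (as in TF-Agents' `tf_agents.keras_layers.dynamic_unroll_layer`). The following is a minimal usage sketch under that assumption; the import path and constructor arguments are assumptions based on that context, not confirmed API:

import tensorflow as tf
from tf_agents.keras_layers import dynamic_unroll_layer  # assumed import path

# Hypothetical setup: an LSTM cell unrolled over n=5 time steps of
# 8-dimensional observations for a batch of 4 trajectories.
cell = tf.keras.layers.LSTMCell(16)
unroll = dynamic_unroll_layer.DynamicUnroll(cell)

batch_size, n, obs_dim = 4, 5, 8
observations = tf.random.normal([batch_size, n, obs_dim])  # batch-major input
# Reset the recurrent state only at the first step of each trajectory.
reset_mask = tf.concat(
    [tf.ones([batch_size, 1], tf.bool),
     tf.zeros([batch_size, n - 1], tf.bool)], axis=1)

outputs, final_state = unroll(observations, reset_mask=reset_mask)
print(outputs.shape)  # expected (4, 5, 16): one output per step, batch-major

Passing `reset_mask` zeroes the recurrent state at the marked positions, which is how episode boundaries are typically handled without splitting the batch.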