Example 1
    def __call__(self, inputs, *args, **kwargs):
        """A wrapper around `Network.call`.

    A typical `call` method in a class subclassing `Network` will have a
    signature that accepts `inputs`, as well as other `*args` and `**kwargs`.
    `call` can optionally also accept `step_type` and `network_state`
    (the latter when `state_spec` is nontrivial, i.e. `state_spec != ()`), e.g.:

    ```python
    def call(self,
             inputs,
             step_type=None,
             network_state=(),
             training=False):
        ...
        return outputs, new_network_state
    ```

    We will validate the first argument (`inputs`)
    against `self.input_tensor_spec` if one is available.

    If a `network_state` kwarg is given it is also validated against
    `self.state_spec`.  Similarly, the return value of the `call` method is
    expected to be a tuple/list with 2 values:  `(output, new_state)`.
    We validate `new_state` against `self.state_spec`.

    If no `network_state` kwarg is given (or if an empty `network_state = ()`
    is given), it is up to `call` to assume a proper "empty" state and to
    emit an appropriate `new_network_state`.
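
    For example (a minimal sketch; `MyNetwork` is a hypothetical subclass of
    `Network` whose `call` follows the signature shown above):

    ```python
    net = MyNetwork(input_tensor_spec=tf.TensorSpec([3], tf.float32))
    # `inputs` is validated against `input_tensor_spec`; the state returned
    # by `call` is validated against `net.state_spec`.
    outputs, new_state = net(tf.ones([2, 3]), network_state=())
    ```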

    Args:
      inputs: The input to `self.call`, matching `self.input_tensor_spec`.
      *args: Additional arguments to `self.call`.
      **kwargs: Additional keyword arguments to `self.call`.
        These can include `network_state` and `step_type`; each is required
        if the network's `call` requires it.

    Returns:
      A tuple `(outputs, new_network_state)`.
    """
        if self.input_tensor_spec is not None:
            nest_utils.assert_matching_dtypes_and_inner_shapes(
                inputs,
                self.input_tensor_spec,
                allow_extra_fields=True,
                caller=self,
                tensors_name="`inputs`",
                specs_name="`input_tensor_spec`")

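        # Inspect `call`'s signature so we know which arguments it accepts.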
        call_argspec = tf_inspect.getargspec(self.call)

        # Convert *args, **kwargs to a canonical kwarg representation.
        normalized_kwargs = tf_inspect.getcallargs(self.call, inputs, *args,
                                                   **kwargs)
        # TODO(b/156315434): Rename network_state to just state.
        network_state = normalized_kwargs.get("network_state", None)
        normalized_kwargs.pop("self", None)

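        # If a non-trivial state was passed in, validate it against `state_spec`.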
        if common.safe_has_state(network_state):
            nest_utils.assert_matching_dtypes_and_inner_shapes(
                network_state,
                self.state_spec,
                allow_extra_fields=True,
                caller=self,
                tensors_name="`network_state`",
                specs_name="`state_spec`")

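        # Drop `step_type` if `call` neither names it nor accepts **kwargs.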
        if "step_type" not in call_argspec.args and not call_argspec.keywords:
            normalized_kwargs.pop("step_type", None)

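        # Likewise drop an empty `network_state` if `call` cannot accept it.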
        if (network_state in (None, ())
                and "network_state" not in call_argspec.args
                and not call_argspec.keywords):
            normalized_kwargs.pop("network_state", None)

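        # Dispatch through the base class's `__call__`, which in turn invokes
        # `self.call` with the normalized kwargs.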
        outputs, new_state = super(Network, self).__call__(**normalized_kwargs)

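        # The state returned by `call` must also match `state_spec`.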
        nest_utils.assert_matching_dtypes_and_inner_shapes(
            new_state,
            self.state_spec,
            allow_extra_fields=True,
            caller=self,
            tensors_name="`new_state`",
            specs_name="`state_spec`")

        return outputs, new_state

    def call(self,
             inputs,
             initial_state=None,
             reset_mask=None,
             training=False):
        """Perform the computation.

    Args:
      inputs: A tuple containing tensors in batch-major format,
        each shaped `[batch_size, n, ...]`.

        If none of the inputs has rank greater than 2 (i.e., all inputs
        are shaped `[batch_size, d]` or `[batch_size]`) then it is assumed that
        a single frame is being calculated and that no time dimension
        was provided.  In this case, a single step is taken and the outputs
        will not have a time dimension either.
      initial_state: (Optional) An initial state for `cell`.  If not provided,
        `dtype` must be set and `cell.get_initial_state()` is used instead.
      reset_mask: (Optional) A `bool` matrix shaped `[batch_size, n]`,
        describing the locations for which the state will be reset to zeros.
        Typically this is the value `time_steps.is_first()` where `time_steps`
        is a `TimeStep` containing tensors of the shape `[batch_size, n, ...]`.
        The `zero_state` of the cell will be used whenever the corresponding
        `reset_mask` entry is `True`, instead of either the current state or
        the `initial_state`.

        If this argument is not provided, state resetting is not performed
        (this tends to speed up the computation by a non-negligible amount).
      training: Whether the output is being used for training.

    Returns:
      A 2-tuple `(outputs, final_state)` where:

       - `outputs` contains the outputs for all states of the unroll; this is
         either a tensor or nested tuple with tensors all shaped
         `[batch_size, n, ...]` (if at least one input had rank `3` or above),
         or `[batch_size, ...]` (if all of the inputs were at most rank `2`),
         with structure and shape matching `cell.output_size`.
       - `final_state` contains the final state of the unroll; with structure
         and shape matching `cell.state_size`.

    Raises:
      ValueError: if static batch sizes within input tensors don't match.
      ValueError: if `initial_state` is `None` and `self.dtype` is `None`.
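
    A minimal usage sketch (assuming this layer wraps a Keras RNN cell as
    `self.cell`; the `DynamicUnroll` constructor shown here is illustrative):

    ```python
    layer = DynamicUnroll(tf.keras.layers.LSTMCell(8), dtype=tf.float32)
    x = tf.ones([4, 10, 3])             # [batch_size, n, input_depth]
    reset = tf.zeros([4, 10], tf.bool)  # never reset within the sequence
    outputs, final_state = layer(x, reset_mask=reset)
    ```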
    """
        initial_state_missing = not common.safe_has_state(initial_state)

        if initial_state_missing and self.dtype is None:
            raise ValueError("Must provide either dtype or initial_state")

        inputs_flat = [
            tf.convert_to_tensor(x, name="input")
            for x in tf.nest.flatten(inputs)
        ]
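        # Treat the inputs as having a time axis only when every input has
        # rank greater than 2 (unknown-rank tensors are assumed to have one).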
        has_time_axis = all(
            [x.shape.ndims is None or x.shape.ndims > 2 for x in inputs_flat])

        if not has_time_axis:
            # No time axis; since we're converting to time major anyway, add a
            # time axis at the front.
            inputs_flat = [tf.expand_dims(x, 0) for x in inputs_flat]
        else:
            # Assume all inputs are batch major.  Convert to time major.
            inputs_flat = [common.transpose_batch_time(x) for x in inputs_flat]

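        # Inputs are now time major, so dimension 1 is the batch dimension.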
        inputs_static_shapes = tuple(x.shape for x in inputs_flat)
        batch_size = _best_effort_input_batch_size(inputs_flat)
        const_batch_size = tensor_shape.dimension_value(
            inputs_static_shapes[0][1])

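        # Repack the flat (now time-major) tensors into the original structure.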
        inputs = tf.nest.pack_sequence_as(inputs, inputs_flat)

        # reset_mask is batch major.  Convert to time major.
        if reset_mask is not None:
            reset_mask = tf.transpose(a=reset_mask)

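        # Check that every input agrees on the static batch size, where known.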
        for shape in inputs_static_shapes:
            got_batch_size = tensor_shape.dimension_value(shape[1])
            if const_batch_size is None:
                const_batch_size = got_batch_size
            if got_batch_size is not None and const_batch_size != got_batch_size:
                raise ValueError(
                    "batch_size is not the same for all the elements in the input. "
                    "Saw values %s and %s" %
                    (const_batch_size, got_batch_size))

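        # Build a zero state for resets; when no initial state was given, also
        # start the unroll from the zero state.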
        if initial_state_missing:
            dtype = self.dtype
            initial_state = zero_state = self.cell.get_initial_state(
                batch_size=batch_size, dtype=self.dtype)
        else:
            dtype = _infer_state_dtype(self.dtype, initial_state)
            zero_state = self.cell.get_initial_state(batch_size=batch_size,
                                                     dtype=dtype)

        # Try to get the iteration count statically; if that's not possible,
        # access it dynamically at runtime.
        iterations = tensor_shape.dimension_value(inputs_flat[0].shape[0])
        iterations = iterations or tf.shape(input=inputs_flat[0])[0]

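        # Use a cheap single-step path when we statically know there is exactly
        # one time step; otherwise run a full dynamic unroll.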
        if not tf.is_tensor(iterations) and iterations == 1:
            # Take exactly one time step
            outputs, new_state = _static_unroll_single_step(
                self.cell,
                inputs,
                reset_mask,
                state=initial_state,
                zero_state=zero_state,
                training=training)
        else:
            outputs, new_state = _dynamic_unroll_multi_step(
                self.cell,
                inputs,
                reset_mask,
                initial_state=initial_state,
                zero_state=zero_state,
                dtype=dtype,
                parallel_iterations=self.parallel_iterations,
                swap_memory=self.swap_memory,
                iterations=iterations,
                const_batch_size=const_batch_size,
                training=training)

        if not has_time_axis:
            # Remove the time axis.
            outputs = tf.nest.map_structure(lambda o: tf.squeeze(o, axis=1),
                                            outputs)

        return outputs, new_state