# Example no. 1 (scraped listing header; vote count: 0)
    def _input_signature(self) -> Optional[tf.TensorSpec]:
        """Return the input signature used by Acme snapshotting.

        Acme snapshotting works as follows: network variables are created via
        the utility function `acme.tf.utils.create_variables()`, which attaches
        an `_input_signature` attribute to the module. That attribute is what
        snapshot saving/loading relies on. When such a module is wrapped (e.g.
        into a DeepRNN), Acme descends into `_layers[0]` of the wrapper to
        locate the signature.

        This implementation lets CriticDeepRNN behave like DeepRNN for both
        orderings:

            1) Variables created *before* wrapping:
              ```
              unwrapped_critic = Critic()
              acme.tf.utils.create_variables(unwrapped_critic, specs)
              critic = CriticDeepRNN([unwrapped_critic])
              ```

            2) Variables created *after* wrapping:
              ```
              unwrapped_critic = Critic()
              critic = CriticDeepRNN([unwrapped_critic])
              acme.tf.utils.create_variables(critic, specs)
              ```

        Returns:
          The input signature of this module, or None if it is not known (i.e.
          the variables were created by `acme.tf.utils.create_variables`
          neither for this module nor for any of its descendants).
        """
        # Case (2): create_variables assigned an _input_signature to this
        # module directly; a setter stored it in __input_signature, so serve
        # the cached value.
        cached = self.__input_signature
        if cached is not None:
            return cached

        # Case (1): descend into self._unwrapped_first_layer and ask
        # savers._get_input_signature for its signature, if any.
        #
        # Ideally savers._get_input_signature would descend into DeepRNN by
        # itself, but that breaks here because CriticDeepRNN._layers[0] is an
        # UnpackWrapper around the underlying module rather than the module.
        signature = savers._get_input_signature(  # pylint: disable=protected-access
            self._unwrapped_first_layer)
        if signature is None:
            return None

        # Wrapping recurrent modules in CriticDeepRNN changes the recurrent
        # state, so rebuild the state spec from a batch-size-1 initial state,
        # widening the batch dimension back to None.
        batch_one_state = self.initial_state(1)
        signature[-1] = tree.map_structure(
            lambda s: tf.TensorSpec((None,) + s.shape[1:], s.dtype),
            batch_one_state)
        self.__input_signature = signature
        return signature
# Example no. 2 (scraped listing header; vote count: 0)
    def _input_signature(self) -> Optional[tf.TensorSpec]:
        """Return input signature for Acme snapshotting, see CriticDeepRNN."""
        # Serve the cached signature if create_variables already set one on
        # this module.
        cached = self.__input_signature
        if cached is not None:
            return cached

        # Otherwise look for a signature on the first wrapped layer.
        signature = savers._get_input_signature(  # pylint: disable=protected-access
            self._layers[0])
        if signature is None:
            return None

        # Wrapping changes the recurrent state, so rebuild its spec from a
        # batch-size-1 initial state with the batch dimension widened to None.
        batch_one_state = self.initial_state(1)
        signature[-1] = tree.map_structure(
            lambda s: tf.TensorSpec((None,) + s.shape[1:], s.dtype),
            batch_one_state)
        self.__input_signature = signature
        return signature