def _input_signature(self) -> Optional[tf.TensorSpec]:
    """Return input signature for Acme snapshotting.

    The Acme way of snapshotting works as follows: you first create your
    network variables via the utility function
    `acme.tf.utils.create_variables()`, which adds an `_input_signature`
    attribute to your module. This attribute is critical for proper snapshot
    saving and loading.

    If a module with such an attribute is wrapped into e.g. DeepRNN, Acme
    descends into the `_layers[0]` of that DeepRNN to find the input
    signature.

    This implementation allows CriticDeepRNN to work seamlessly like DeepRNN
    for the following two use-cases:

    1) Creating variables *before* wrapping:
        ```
        unwrapped_critic = Critic()
        acme.tf.utils.create_variables(unwrapped_critic, specs)
        critic = CriticDeepRNN([unwrapped_critic])
        ```

    2) Create variables *after* wrapping:
        ```
        unwrapped_critic = Critic()
        critic = CriticDeepRNN([unwrapped_critic])
        acme.tf.utils.create_variables(critic, specs)
        ```

    Returns:
        input_signature of the module, or None if it is not known (i.e. the
        variables were created by acme.tf.utils.create_variables neither for
        this module nor for any of its descendants). When the signature is
        derived from the wrapped first layer, its last entry is replaced by
        the spec of this module's own recurrent state and the result is
        cached for subsequent calls.
    """
    if self.__input_signature is not None:
        # To make case (2) (see above) work, we need to allow
        # create_variables to assign an _input_signature attribute to this
        # module, which is why we create additional __input_signature
        # attribute with a setter (see below).
        return self.__input_signature

    # To make case (1) work, we descend into self._unwrapped_first_layer
    # and try to get its input signature (if it exists) by calling
    # savers.get_input_signature.
    # Ideally, savers.get_input_signature should automatically descend into
    # DeepRNN. But in this case it breaks on CriticDeepRNN because
    # CriticDeepRNN._layers[0] is an UnpackWrapper around the underlying
    # module and not the module itself.
    wrapped_signature = savers._get_input_signature(  # pylint: disable=protected-access
        self._unwrapped_first_layer)
    if wrapped_signature is None:
        return None

    # Work on a shallow copy: the object returned above presumably is the
    # wrapped layer's own `_input_signature`, and replacing its last entry
    # in place would silently rewrite the signature of the wrapped module.
    input_signature = list(wrapped_signature)

    # Since adding recurrent modules via CriticDeepRNN changes the recurrent
    # state, we need to update its spec here. The leading dimension of every
    # state tensor is replaced by None so that it is left unconstrained.
    state = self.initial_state(1)
    input_signature[-1] = tree.map_structure(
        lambda t: tf.TensorSpec((None,) + t.shape[1:], t.dtype), state)

    self.__input_signature = input_signature
    return input_signature
def _input_signature(self) -> Optional[tf.TensorSpec]:
    """Return input signature for Acme snapshotting, see CriticDeepRNN.

    If a signature has already been stored in `self.__input_signature`, it is
    returned as is. Otherwise the signature of the first layer
    (`self._layers[0]`) is looked up via `savers._get_input_signature`, its
    last entry is replaced by a spec describing this module's own
    `initial_state`, and the result is cached.

    Returns:
        The input signature of this module, or None if no signature could be
        found for the first layer.
    """
    if self.__input_signature is not None:
        return self.__input_signature

    first_layer_signature = savers._get_input_signature(  # pylint: disable=protected-access
        self._layers[0])
    if first_layer_signature is None:
        return None

    # Work on a shallow copy: the object returned above presumably is the
    # first layer's own `_input_signature`, and replacing its last entry in
    # place would silently rewrite the signature of that layer.
    input_signature = list(first_layer_signature)

    # The last entry is the state spec; rebuild it from this module's own
    # initial state, leaving the leading dimension of every state tensor
    # unconstrained (None).
    state = self.initial_state(1)
    input_signature[-1] = tree.map_structure(
        lambda t: tf.TensorSpec((None,) + t.shape[1:], t.dtype), state)

    self.__input_signature = input_signature
    return input_signature